inference_equation_torch.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import os
  2. import random
  3. import re
  4. import sys
  5. from glob import glob
  6. import cv2
  7. import numpy as np
  8. import torch
  9. from torch.utils.data import DataLoader
  10. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  11. from model_torch import crnn_ctc_equation_torch6
  12. from pre_process_torch import EquationDataset, py_ctc_decode
  # Directory containing this file; all bundled resources are resolved from it.
  package_dir = os.path.abspath(os.path.dirname(__file__))
  model_path = os.path.join(package_dir, "models", "equation6_model_acc-0.853.pth")
  # Seeds the global stdlib `random` state (shared with every module that uses it).
  random.seed(42)
  device = torch.device("cpu")
  # Last entry is forwarded as the channel count to EquationDataset;
  # presumably (height, width, channels) — confirm against EquationDataset.
  image_shape = (32, 192, 3)
  # Parent of this file's directory, kept with a trailing separator.
  project_root = os.path.join(package_dir, os.pardir, "")
  # Presumably 35 output symbols plus one CTC blank — confirm against
  # equation_torch.txt and the model head.
  class_num = 35 + 1
  batch_size = 1
  # Forwarded to EquationDataset; presumably the CTC input sequence length and
  # the maximum label length — verify in pre_process_torch.
  input_len = 12
  label_len = 8
  # Character set file, one entry per line.
  with open(os.path.join(package_dir, "equation_torch.txt"), "r", encoding="utf-8") as f:
      char_list = f.readlines()
  # All characters concatenated, with line breaks removed.
  char_str = "".join(char_list).replace("\n", "")
  def recognize(image_np, model=None):
      """Recognize an equation image and return the value of the equation.

      Args:
          image_np: the image to recognize (the ``__main__`` block passes the
              output of ``cv2.imread``). It is wrapped in a one-item
              ``EquationDataset``.
          model: optional already-loaded model. When ``None``, a new
              ``crnn_ctc_equation_torch6`` is built and its weights are loaded
              from ``model_path`` on every call; pass a model to avoid that.

      Returns:
          The integer computed by ``calculate`` for the first decoded
          candidate that parses, or ``None`` if no candidate parses.
      """
      if model is None:
          model = crnn_ctc_equation_torch6(class_num)
          # `device` is already a torch.device, so pass it straight through.
          model.load_state_dict(torch.load(model_path, map_location=device))
          model.eval()

      dataset = EquationDataset(
          [image_np], image_shape, input_len, label_len,
          channel=image_shape[-1], mode=1,
      )
      data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

      calculate_result = None
      # Last candidate tried; stays None if the decoder yields no candidates
      # (previously this left `result` unbound and crashed the print below).
      result = None
      with torch.no_grad():
          for data, _targets, _, _ in data_loader:
              data = data.to(device=device, dtype=torch.float32)
              outputs = model(data)
              candidates = py_ctc_decode(outputs)[0]
              for result in candidates:
                  cal = calculate(result)
                  # 0 is a valid answer (e.g. "5-5"), so test for None rather
                  # than truthiness.
                  if cal is not None:
                      calculate_result = cal
                      break
      print("cer result", result, calculate_result)
      return calculate_result
  # Maps each recognizable character to a digit (int), an operator name
  # ("加" add, "减" subtract, "乘" multiply, "除" divide), or "" for filler
  # characters that are dropped from the equation (e.g. the 以 in 乘以).
  _CALC_CHAR_MAP = {
      "1": 1,
      "2": 2,
      "3": 3,
      "4": 4,
      "5": 5,
      "6": 6,
      "7": 7,
      "8": 8,
      "9": 9,
      "0": 0,
      "一": 1,
      "二": 2,
      "三": 3,
      "四": 4,
      "五": 5,
      "六": 6,
      "七": 7,
      "八": 8,
      "九": 9,
      "零": 0,
      "加": "加",
      "减": "减",
      "乘": "乘",
      "除": "除",
      "+": "加",
      "-": "减",
      "*": "乘",
      "×": "乘",
      "/": "除",
      "÷": "除",
      "=": "",
      "?": "",
      "上": "",
      "去": "",
      "以": "",
  }
  _CALC_OP_PATTERN = re.compile("加|减|乘|除")


  def calculate(_list):
      """Evaluate a decoded equation such as ``"3+4=?"`` or ``"三加四"``.

      Args:
          _list: iterable of single characters (a ``str`` or a list of
              ``str``); ``recognize`` passes each decoded candidate.

      Returns:
          The integer result of the single binary operation, or ``None`` when
          the input contains an unrecognized character, does not contain
          exactly one operator, or an operand is missing. Division is integer
          division, and a zero divisor is treated as 1 (so ``n÷0`` yields ``n``).
      """
      tokens = []
      for c in _list:
          mapped = _CALC_CHAR_MAP.get(c)
          if mapped is None:
              # Unrecognized character: the equation cannot be evaluated.
              return None
          tokens.append(str(mapped))
      equation_str = "".join(tokens)

      operators = _CALC_OP_PATTERN.findall(equation_str)
      # Exactly one operator occurrence, i.e. exactly two operands.
      if len(operators) != 1:
          return None
      op = operators[0]
      left, right = _CALC_OP_PATTERN.split(equation_str)
      try:
          num1 = int(left)
          num2 = int(right)
      except ValueError:
          # An operand is empty, e.g. "加4".
          print("非数字!")
          return None

      if op == "加":
          return num1 + num2
      if op == "减":
          return num1 - num2
      if op == "乘":
          return num1 * num2
      # op == "除". Operands are built from digits only, so they are
      # non-negative and floor division equals truncation. `//` stays exact
      # for large values, where float division would lose precision.
      return num1 // max(num2, 1)
  if __name__ == "__main__":
      # Recognize every .jpg in the current working directory.
      # Load the weights once and reuse them: recognize() rebuilds the model
      # from disk on every call when no model is passed in.
      _model = crnn_ctc_equation_torch6(class_num)
      _model.load_state_dict(torch.load(model_path, map_location=device))
      _model.eval()
      for _path in sorted(glob("./*.jpg")):
          _image = cv2.imread(_path)
          # cv2.imread returns None instead of raising for unreadable files.
          if _image is None:
              print("skip unreadable image", _path)
              continue
          recognize(_image, model=_model)