inference_equation.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. import os
  2. import re
  3. from glob import glob
  4. import cv2
  5. import numpy as np
  6. from click_captcha.inference_equation_denoise import denoise
  7. from click_captcha.model import crnn_ctc_equation, ctc_decode, crnn_ctc_equation_large, crnn_ctc_equation_less, \
  8. crnn_ctc_equation_loss
  9. from click_captcha.pre_process import eight_neighbour, connected_component, add_contrast
  10. from click_captcha.utils import pil_resize
  11. image_shape = (32, 192, 1)
  12. weights_path = "./models/e55-loss0.14-equation.h5"
  13. project_dir = os.path.dirname(os.path.abspath(__file__)) + "/../"
  14. model = crnn_ctc_equation_loss(input_shape=image_shape, class_num=35+2, is_train=False)
  15. model.load_weights(weights_path)
  16. def recognize(image_path):
  17. X = []
  18. img = cv2.imread(image_path)
  19. img = pil_resize(img, image_shape[0], image_shape[1])
  20. # cv2.imshow("img", img)
  21. img = denoise(img)
  22. # cv2.imshow("denoise", img)
  23. img = add_contrast(img)
  24. if img.shape[2] == 3:
  25. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  26. img = np.expand_dims(img, axis=-1)
  27. # _, img = cv2.threshold(img, 110, 255, cv2.THRESH_BINARY)
  28. # cv2.imshow("contrast", img)
  29. # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  30. # gray = eight_neighbour(gray, 4)
  31. # cv2.imshow("eight_neighbour", gray)
  32. # img2 = connected_component(gray)
  33. # cv2.imshow("connected_component", img2)
  34. # img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  35. # img = np.expand_dims(img, axis=-1)
  36. img = img / 255.
  37. X.append(img)
  38. X = np.array(X)
  39. pred = ctc_decode(X, model)
  40. print(image_path)
  41. pred = pred[0][0]
  42. with open(project_dir + "data/equation.txt", 'r') as f:
  43. char_list = f.readlines()
  44. char_str = "".join(char_list)
  45. char_str = re.sub("\n", "", char_str)
  46. result_list = []
  47. for index in pred:
  48. index = int(index-1)
  49. # index = int(index)
  50. if index < 0:
  51. continue
  52. # print(char_str[index])
  53. result_list.append(char_str[index])
  54. print(result_list)
  55. label = image_path.split("_")[-1].split(".")[0]
  56. pred = calculate(result_list)
  57. print("计算结果:", pred, label)
  58. # cv2.waitKey(0)
  59. if pred is None:
  60. return 0
  61. if int(label) == int(pred):
  62. return 1
  63. else:
  64. return 0
  65. def calculate1(_list):
  66. char_dict = {
  67. "1": 1,
  68. "2": 2,
  69. "3": 3,
  70. "4": 4,
  71. "5": 5,
  72. "6": 6,
  73. "7": 7,
  74. "8": 8,
  75. "9": 9,
  76. "0": 0,
  77. "一": 1,
  78. "二": 2,
  79. "三": 3,
  80. "四": 4,
  81. "五": 5,
  82. "六": 6,
  83. "七": 7,
  84. "八": 8,
  85. "九": 9,
  86. "零": 0,
  87. "加": "加",
  88. "减": "减",
  89. "乘": "乘",
  90. "除": "除",
  91. "+": "加",
  92. "-": "减",
  93. "*": "乘",
  94. "×": "乘",
  95. "/": "除",
  96. "÷": "除",
  97. "=": "",
  98. "?": "",
  99. "上": "",
  100. "去": "",
  101. "以": "",
  102. }
  103. equation_str = ""
  104. for c in _list:
  105. equation_str += str(char_dict.get(c))
  106. op = re.findall("加|减|乘|除", equation_str)
  107. op = list(set(op))
  108. if len(op) != 1:
  109. return None
  110. nums = re.split("加|减|乘|除", equation_str)
  111. if len(nums) != 2:
  112. return None
  113. try:
  114. num1 = int(nums[0])
  115. num2 = int(nums[1])
  116. except:
  117. print("非数字!")
  118. return None
  119. op = op[0]
  120. if op == "加":
  121. result = num1 + num2
  122. elif op == '减':
  123. result = num1 - num2
  124. elif op == '乘':
  125. result = num1 * num2
  126. elif op == '除':
  127. result = int(num1 / max(num2, 1))
  128. return result
  129. def calculate(_list):
  130. char_dict = {
  131. "1": 1,
  132. "2": 2,
  133. "3": 3,
  134. "4": 4,
  135. "5": 5,
  136. "6": 6,
  137. "7": 7,
  138. "8": 8,
  139. "9": 9,
  140. "0": 0,
  141. "一": 1,
  142. "二": 2,
  143. "三": 3,
  144. "四": 4,
  145. "五": 5,
  146. "六": 6,
  147. "七": 7,
  148. "八": 8,
  149. "九": 9,
  150. "零": 0,
  151. "加": "加",
  152. "减": "减",
  153. "乘": "乘",
  154. "除": "除",
  155. "+": "加",
  156. "-": "减",
  157. "*": "乘",
  158. "×": "乘",
  159. "/": "除",
  160. "÷": "除",
  161. "=": "",
  162. "?": "",
  163. "上": "",
  164. "去": "",
  165. "以": "",
  166. }
  167. equation_str = ""
  168. for c in _list:
  169. equation_str += str(char_dict.get(c))
  170. op = re.findall("加|减|乘|除", equation_str)
  171. op = list(set(op))
  172. if len(op) != 1:
  173. return None
  174. nums = re.split("加|减|乘|除", equation_str)
  175. if len(nums) != 2:
  176. return None
  177. try:
  178. num1 = int(nums[0])
  179. num2 = int(nums[1])
  180. except:
  181. print("非数字!")
  182. return None
  183. op = op[0]
  184. if op == "加":
  185. result = num1 + num2
  186. elif op == '减':
  187. result = num1 - num2
  188. elif op == '乘':
  189. result = num1 * num2
  190. elif op == '除':
  191. result = int(num1 / max(num2, 1))
  192. return result
  193. if __name__ == "__main__":
  194. # _path = "../data/test/char_9.jpg"
  195. # _path = "../data/equation/38376_减_1_问_加_4_除.jpg"
  196. _paths = glob("../data/test/FileInfo1021/*")
  197. right_num = 0
  198. for _path in _paths:
  199. r = recognize(_path)
  200. if r:
  201. right_num += 1
  202. print("准确率:", right_num / len(_paths), right_num, len(_paths))