inference_equation.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import os
  2. import re
  3. import sys
  4. from glob import glob
  5. import cv2
  6. import numpy as np
  7. import tensorflow as tf
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  9. from model import crnn_ctc_equation, ctc_decode, crnn_ctc_equation_large
  10. from utils import pil_resize, add_contrast
  11. package_dir = os.path.abspath(os.path.dirname(__file__))
  12. image_shape = (32, 192, 1)
  13. model_path = package_dir + "/models/e55-loss0.14-equation.h5"
  14. with open(package_dir + "/equation.txt", 'r', encoding='utf-8') as f:
  15. char_list = f.readlines()
  16. char_str = "".join(char_list)
  17. char_str = re.sub("\n", "", char_str)
  18. def recognize(image_np, model=None, sess=None):
  19. if sess is None:
  20. sess = tf.compat.v1.Session(graph=tf.Graph())
  21. if model is None:
  22. with sess.as_default():
  23. with sess.graph.as_default():
  24. model = crnn_ctc_equation_large(input_shape=image_shape, class_num=35+2, is_train=False)
  25. model.load_weights(model_path)
  26. img = image_np
  27. img = pil_resize(img, image_shape[0], image_shape[1])
  28. if img.shape[2] == 3:
  29. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  30. img = np.expand_dims(img, axis=-1)
  31. img = add_contrast(img)
  32. # cv2.imshow("contrast", img)
  33. X = []
  34. img = img / 255.
  35. X.append(img)
  36. X = np.array(X)
  37. with sess.as_default():
  38. with sess.graph.as_default():
  39. pred = ctc_decode(X, model)
  40. result_list = []
  41. for index in pred:
  42. index = int(index-1)
  43. if index < 0:
  44. continue
  45. result_list.append(char_str[index])
  46. # print(char_str[index])
  47. # cv2.waitKey(0)
  48. print("cer result", result_list)
  49. return calculate(result_list)
  50. def calculate(_list):
  51. char_dict = {
  52. "1": 1,
  53. "2": 2,
  54. "3": 3,
  55. "4": 4,
  56. "5": 5,
  57. "6": 6,
  58. "7": 7,
  59. "8": 8,
  60. "9": 9,
  61. "0": 0,
  62. "一": 1,
  63. "二": 2,
  64. "三": 3,
  65. "四": 4,
  66. "五": 5,
  67. "六": 6,
  68. "七": 7,
  69. "八": 8,
  70. "九": 9,
  71. "零": 0,
  72. "加": "加",
  73. "减": "减",
  74. "乘": "乘",
  75. "除": "除",
  76. "+": "加",
  77. "-": "减",
  78. "*": "乘",
  79. "×": "乘",
  80. "/": "除",
  81. "÷": "除",
  82. "=": "",
  83. "?": "",
  84. "上": "",
  85. "去": "",
  86. "以": "",
  87. }
  88. equation_str = ""
  89. for c in _list:
  90. equation_str += str(char_dict.get(c))
  91. op = re.findall("加|减|乘|除", equation_str)
  92. op = list(set(op))
  93. if len(op) != 1:
  94. return None
  95. nums = re.split("加|减|乘|除", equation_str)
  96. if len(nums) != 2:
  97. return None
  98. try:
  99. num1 = int(nums[0])
  100. num2 = int(nums[1])
  101. except:
  102. print("非数字!")
  103. return None
  104. op = op[0]
  105. if op == "加":
  106. result = num1 + num2
  107. elif op == '减':
  108. result = num1 - num2
  109. elif op == '乘':
  110. result = num1 * num2
  111. elif op == '除':
  112. result = int(num1 / max(num2, 1))
  113. return result
  114. if __name__ == "__main__":
  115. # _path = "../data/test/char_9.jpg"
  116. # _path = "../data/equation/38376_减_1_问_加_4_除.jpg"
  117. _paths = glob("../data/test/FileInfo1021/*")
  118. for _path in _paths:
  119. recognize(_path)