import os
import re
import sys
from glob import glob

import cv2
import numpy as np
import tensorflow as tf

# Make the parent directory importable so the project-local `model` and
# `utils` modules below can be found when this file is run directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from model import crnn_ctc_equation, ctc_decode, crnn_ctc_equation_large  # noqa: E402,F401
from utils import pil_resize, add_contrast  # noqa: E402

package_dir = os.path.abspath(os.path.dirname(__file__))

# Network input shape passed to the model builder: (height, width, channels).
image_shape = (32, 192, 1)
model_path = os.path.join(package_dir, "models", "e55-loss0.14-equation.h5")

# Recognition vocabulary: all newlines are stripped, so char_str[i] is the
# symbol for decoded class i + 1 (see `recognize`).
with open(os.path.join(package_dir, "equation.txt"), "r", encoding="utf-8") as f:
    char_list = f.readlines()
char_str = re.sub("\n", "", "".join(char_list))

# Output classes of the network: 35 vocabulary symbols + 2 extra classes.
# NOTE(review): hard-coded; presumably must equal len(char_str) + 2 — confirm
# whenever equation.txt changes.
_CLASS_NUM = 35 + 2


def _load_model(sess):
    """Build the large CRNN-CTC equation model in *sess*'s graph and load its weights."""
    with sess.as_default():
        with sess.graph.as_default():
            model = crnn_ctc_equation_large(input_shape=image_shape,
                                            class_num=_CLASS_NUM,
                                            is_train=False)
            model.load_weights(model_path)
    return model


def _preprocess(image_np):
    """Resize, convert to single-channel, contrast-enhance and scale to [0, 1].

    Returns a batch of one image (leading batch axis added).
    """
    img = pil_resize(image_np, image_shape[0], image_shape[1])
    if img.ndim == 2:
        # Already grayscale: only the channel axis is missing.
        img = np.expand_dims(img, axis=-1)
    elif img.shape[2] == 3:
        # 3-channel input is treated as OpenCV BGR order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = np.expand_dims(img, axis=-1)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
        img = np.expand_dims(img, axis=-1)
    img = add_contrast(img)
    return np.expand_dims(img / 255., axis=0)


def recognize(image_np, model=None, sess=None):
    """Recognize an arithmetic-equation image and return its evaluated result.

    Args:
        image_np: image array; a 3-channel image is treated as BGR (OpenCV
            order), 4-channel as BGRA, 2-D as already grayscale.
        model: a loaded model; built from `model_path` when None.
        sess: TF1-style session whose graph owns *model*. When None a
            temporary session is created and closed before returning.
            NOTE(review): passing *model* without its *sess* runs the model
            in a fresh, unrelated graph — callers should pass both.

    Returns:
        The integer result from `calculate`, or None if the recognized
        symbols do not form a valid single-operator equation.
    """
    X = _preprocess(image_np)

    owns_session = sess is None
    if owns_session:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        if model is None:
            model = _load_model(sess)
        with sess.as_default():
            with sess.graph.as_default():
                pred = ctc_decode(X, model)
    finally:
        # Only close sessions this call created; a caller's session stays open.
        if owns_session:
            sess.close()

    result_list = []
    for index in pred:
        # Decoded class k maps to char_str[k - 1]; class 0 (and any negative
        # padding value) carries no symbol. Out-of-range classes are skipped
        # instead of raising IndexError.
        index = int(index - 1)
        if index < 0 or index >= len(char_str):
            continue
        result_list.append(char_str[index])
    print("cer result", result_list)
    return calculate(result_list)


# Symbol -> equation token. Digits (Arabic or Chinese) map to their value,
# operators (symbol or Chinese word) map to one canonical Chinese word, and
# filler symbols map to "" so they disappear from the equation string.
_CHAR_DICT = {
    "1": 1, "2": 2, "3": 3, "4": 4, "5": 5,
    "6": 6, "7": 7, "8": 8, "9": 9, "0": 0,
    "一": 1, "二": 2, "三": 3, "四": 4, "五": 5,
    "六": 6, "七": 7, "八": 8, "九": 9, "零": 0,
    "加": "加", "减": "减", "乘": "乘", "除": "除",
    "+": "加", "-": "减", "*": "乘", "×": "乘", "/": "除", "÷": "除",
    "=": "", "?": "", "上": "", "去": "", "以": "",
}

_OP_PATTERN = re.compile("加|减|乘|除")

# Canonical operator word -> integer operation. Operands are always
# non-negative (they are parsed from digit-only strings), so floor division
# matches truncating division. A zero divisor is deliberately treated as 1.
_OPERATIONS = {
    "加": lambda a, b: a + b,
    "减": lambda a, b: a - b,
    "乘": lambda a, b: a * b,
    "除": lambda a, b: a // max(b, 1),
}


def calculate(_list):
    """Evaluate a recognized equation of the form ``<number> <op> <number>``.

    Args:
        _list: sequence of recognized symbols (keys of `_CHAR_DICT`).

    Returns:
        The integer result, or None if a symbol is unknown, the equation does
        not contain exactly one operator occurrence with two operands, or an
        operand is not a number.
    """
    tokens = []
    for c in _list:
        if c not in _CHAR_DICT:
            # Unknown symbol: the equation cannot be parsed reliably.
            return None
        tokens.append(str(_CHAR_DICT[c]))
    equation_str = "".join(tokens)

    ops = set(_OP_PATTERN.findall(equation_str))
    if len(ops) != 1:
        return None
    nums = _OP_PATTERN.split(equation_str)
    if len(nums) != 2:
        return None
    try:
        num1 = int(nums[0])
        num2 = int(nums[1])
    except ValueError:
        # e.g. an operand is empty because the operator came first or last.
        print("非数字!")
        return None
    return _OPERATIONS[ops.pop()](num1, num2)


def _read_image(path):
    """Read an image as BGR; works with non-ASCII paths. Returns None on failure."""
    if not os.path.isfile(path):
        return None
    # cv2.imread cannot open non-ASCII paths on Windows; decode from bytes instead.
    data = np.fromfile(path, dtype=np.uint8)
    return cv2.imdecode(data, cv2.IMREAD_COLOR)


def main(pattern="../data/test/FileInfo1021/*"):
    """Recognize every image matching *pattern*, loading the model only once."""
    sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        model = _load_model(sess)
        for path in sorted(glob(pattern)):
            image = _read_image(path)
            if image is None:
                print("cannot read image:", path)
                continue
            print(path, recognize(image, model=model, sess=sess))
    finally:
        sess.close()


if __name__ == "__main__":
    # Optional first CLI argument overrides the glob pattern.
    main(*sys.argv[1:2])