"""Recognize a simple two-operand arithmetic equation in an image and evaluate it.

The module loads a CRNN+CTC model, decodes candidate character sequences from
an input image, and returns the value of the first candidate that parses as
``<number> <operator> <number>``.
"""
import os
import random
import re
import sys
from glob import glob

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

# Make the parent directory importable before the project-local imports below.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from model_torch import crnn_ctc_equation_torch6
from pre_process_torch import EquationDataset, py_ctc_decode

package_dir = os.path.abspath(os.path.dirname(__file__))
model_path = package_dir + "/models/equation6_model_acc-0.853.pth"

random.seed(42)

device = torch.device("cpu")
image_shape = (32, 192, 3)
project_root = os.path.dirname(os.path.abspath(__file__)) + "/../"
class_num = 35 + 1
batch_size = 1
input_len = 12
label_len = 8

# Character table: the file's lines are concatenated with newlines removed.
with open(package_dir + "/equation_torch.txt", 'r', encoding='utf-8') as f:
    char_list = f.readlines()
char_str = "".join(char_list)
char_str = re.sub("\n", "", char_str)

# Maps every recognizable symbol to a digit, a canonical operator name, or ""
# for filler characters that carry no arithmetic meaning.
_CHAR_MAP = {
    "1": 1, "2": 2, "3": 3, "4": 4, "5": 5,
    "6": 6, "7": 7, "8": 8, "9": 9, "0": 0,
    "一": 1, "二": 2, "三": 3, "四": 4, "五": 5,
    "六": 6, "七": 7, "八": 8, "九": 9, "零": 0,
    "加": "加", "减": "减", "乘": "乘", "除": "除",
    "+": "加", "-": "减", "*": "乘", "×": "乘", "/": "除", "÷": "除",
    "=": "", "?": "", "上": "", "去": "", "以": "",
}

# Matches any one canonical operator name.
_OPERATOR_PATTERN = re.compile("加|减|乘|除")


def recognize(image_np, model=None):
    """Decode the equation shown in ``image_np`` and return its computed value.

    Args:
        image_np: The image to recognize; it is handed unchanged to
            ``EquationDataset`` (presumably a numpy array such as the result
            of ``cv2.imread`` -- confirm against ``EquationDataset``).
        model: An already-loaded model to run. When ``None``, a new
            ``crnn_ctc_equation_torch6`` is built, loaded from ``model_path``
            and switched to eval mode.

    Returns:
        The integer value of the first decoded candidate that ``calculate``
        can evaluate (``0`` is a valid value), or ``None`` when no candidate
        can be evaluated.
    """
    if model is None:
        model = crnn_ctc_equation_torch6(class_num)
        model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
        # NOTE(review): eval() is applied only to the model built here;
        # a caller-supplied model is used in whatever mode it arrives in.
        model.eval()

    dataset = EquationDataset([image_np], image_shape, input_len, label_len,
                              channel=image_shape[-1], mode=1)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    calculate_result = None
    with torch.no_grad():
        for data, targets, _, _ in data_loader:
            data = data.to(device)
            data = data.to(torch.float32)
            outputs = model(data)

            # Only the first element of the decode output is used; it is
            # iterated as a list of candidate character sequences.
            result_list = py_ctc_decode(outputs)[0]

            # Track the last candidate tried so the log line below stays
            # valid even when the decoder produced no candidates at all.
            last_candidate = None
            for candidate in result_list:
                last_candidate = candidate
                cal = calculate(candidate)
                # Compare against None explicitly: an equation such as
                # "3-3" legitimately evaluates to 0, which is falsy.
                if cal is not None:
                    calculate_result = cal
                    break
            print("cer result", last_candidate, calculate_result)
    return calculate_result


def calculate(_list):
    """Evaluate a recognized ``<number> <operator> <number>`` expression.

    Args:
        _list: Iterable of recognized characters (ASCII digits, Chinese
            numerals, operator symbols or operator words, and filler
            characters such as "=" and "?").

    Returns:
        The integer result, or ``None`` when the sequence contains an unknown
        character, does not contain exactly one kind of operator, does not
        split into exactly two operands, or an operand is not a number.
        Division is integer division, and a zero divisor is treated as one.
    """
    pieces = []
    for c in _list:
        token = _CHAR_MAP.get(c)
        if token is None:
            # A symbol outside the table makes the expression unparseable.
            print("非数字!")
            return None
        pieces.append(str(token))
    equation_str = "".join(pieces)

    # Exactly one distinct operator is allowed.
    operators = set(_OPERATOR_PATTERN.findall(equation_str))
    if len(operators) != 1:
        return None

    # Exactly two operands are allowed (this also rejects a repeated operator).
    nums = _OPERATOR_PATTERN.split(equation_str)
    if len(nums) != 2:
        return None

    try:
        num1 = int(nums[0])
        num2 = int(nums[1])
    except ValueError:
        # An operand is empty, e.g. the expression starts with an operator.
        print("非数字!")
        return None

    op = operators.pop()
    if op == "加":
        return num1 + num2
    if op == "减":
        return num1 - num2
    if op == "乘":
        return num1 * num2
    # op == "除": operands are never negative here, so floor division equals
    # truncation while staying exact for large integers. max(..., 1) avoids
    # ZeroDivisionError by treating a zero divisor as one.
    return num1 // max(num2, 1)


if __name__ == "__main__":
    # Example inputs:
    #   "../data/test/char_9.jpg"
    #   "../data/equation/38376_减_1_问_加_4_除.jpg"
    _paths = glob("./*.jpg")
    for _path in _paths:
        _image = cv2.imread(_path)
        if _image is None:
            # cv2.imread returns None for files it cannot read or decode.
            print("cannot read image", _path)
            continue
        recognize(_image)