123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- import os
- import random
- import re
- import sys
- from glob import glob
- import cv2
- import numpy as np
- import torch
- from torch.utils.data import DataLoader
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from model_torch import crnn_ctc_equation_torch6
- from pre_process_torch import EquationDataset, py_ctc_decode
# --- Paths -----------------------------------------------------------------
# Directory containing this module; the weights file and the character-set
# file are both located relative to it.
package_dir = os.path.abspath(os.path.dirname(__file__))
model_path = package_dir + "/models/equation6_model_acc-0.853.pth"
project_root = os.path.dirname(os.path.abspath(__file__)) + "/../"

# --- Runtime settings ------------------------------------------------------
random.seed(42)
device = torch.device("cpu")

# --- Model / dataset hyper-parameters ---------------------------------------
# The last entry of image_shape is passed to EquationDataset as `channel`;
# presumably the layout is (height, width, channels) - TODO confirm.
image_shape = (32, 192, 3)
# NOTE(review): the "+ 1" is presumably the CTC blank class - confirm against
# the model definition.
class_num = 35 + 1
batch_size = 1
# input_len / label_len are forwarded unchanged to EquationDataset.
input_len = 12
label_len = 8

# --- Character set ----------------------------------------------------------
# char_list keeps the raw lines of the file; char_str is the same text with
# every newline removed.
with open(package_dir + "/equation_torch.txt", 'r', encoding='utf-8') as f:
    char_list = f.readlines()
char_str = re.sub("\n", "", "".join(char_list))
def recognize(image_np, model=None):
    """Recognise the equation shown in an image and return its value.

    Args:
        image_np: Image array; it is wrapped in a one-element list and handed
            to ``EquationDataset`` (the script entry point passes the result
            of ``cv2.imread``).
        model: Optional network to run inference with. When ``None``, a
            ``crnn_ctc_equation_torch6`` is built, its weights are loaded
            from ``model_path`` and it is switched to eval mode.

    Returns:
        The integer returned by ``calculate`` for the first decoded candidate
        (in iteration order) that parses as an equation, or ``None`` when no
        candidate parses.
    """
    if model is None:
        model = crnn_ctc_equation_torch6(class_num)
        # `device` is already a torch.device, so it is passed through as-is.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

    dataset = EquationDataset(
        [image_np],
        image_shape,
        input_len,
        label_len,
        channel=image_shape[-1],
        mode=1,
    )
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    calculate_result = None
    with torch.no_grad():
        for data, _targets, _, _ in data_loader:
            data = data.to(device).to(torch.float32)
            outputs = model(data)

            # NOTE(review): index 0 is presumably the candidate list for the
            # single sample in the batch - confirm against py_ctc_decode.
            result_list = py_ctc_decode(outputs)[0]

            # Pre-bind so the log line below works even when the decoder
            # returns no candidates at all.
            result = None
            for result in result_list:
                cal = calculate(result)
                # Compare against None explicitly: 0 is a valid answer
                # (e.g. "1-1") and must not be treated as a failed parse.
                if cal is not None:
                    calculate_result = cal
                    break
            print("cer result", result, calculate_result)
    return calculate_result
# Maps every recognisable character to a digit, a canonical operator name,
# or "" for filler characters that carry no arithmetic meaning.
_EQUATION_CHAR_MAP = {
    "1": 1,
    "2": 2,
    "3": 3,
    "4": 4,
    "5": 5,
    "6": 6,
    "7": 7,
    "8": 8,
    "9": 9,
    "0": 0,
    "一": 1,
    "二": 2,
    "三": 3,
    "四": 4,
    "五": 5,
    "六": 6,
    "七": 7,
    "八": 8,
    "九": 9,
    "零": 0,
    "加": "加",
    "减": "减",
    "乘": "乘",
    "除": "除",
    "+": "加",
    "-": "减",
    "*": "乘",
    "×": "乘",
    "/": "除",
    "÷": "除",
    "=": "",
    "?": "",
    "上": "",
    "去": "",
    "以": "",
}

# Matches any one of the four canonical operator names.
_EQUATION_OPERATOR_RE = re.compile("加|减|乘|除")


def calculate(_list):
    """Evaluate a recognised ``<number> <operator> <number>`` expression.

    Args:
        _list: Iterable of single characters (digits, Chinese numerals,
            operator symbols/words and filler characters such as ``=``).

    Returns:
        The integer result, or ``None`` when the characters do not form
        exactly two integer operands joined by exactly one operator.
        Division is integer division; a divisor of ``0`` is treated as ``1``.
    """
    # Unmapped characters turn into the text "None" (str(None)); that text
    # can never be parsed by int(), so such input ends up rejected below.
    equation_str = "".join(str(_EQUATION_CHAR_MAP.get(c)) for c in _list)

    # Exactly one distinct operator kind must be present ...
    operators = set(_EQUATION_OPERATOR_RE.findall(equation_str))
    if len(operators) != 1:
        return None

    # ... and it must occur exactly once, leaving two operands.
    nums = _EQUATION_OPERATOR_RE.split(equation_str)
    if len(nums) != 2:
        return None

    try:
        num1 = int(nums[0])
        num2 = int(nums[1])
    except ValueError:
        print("非数字!")
        return None

    op = operators.pop()
    if op == "加":
        return num1 + num2
    if op == "减":
        return num1 - num2
    if op == "乘":
        return num1 * num2
    # Remaining case is "除". Operands are never negative (a minus sign is
    # always parsed as the operator), so floor division equals truncation
    # while avoiding the float round-trip that loses precision on large
    # operands. The divisor is clamped to 1 to keep division by zero from
    # raising.
    return num1 // max(num2, 1)
if __name__ == "__main__":
    # Script mode: run recognition on every JPEG in the current working
    # directory. recognize() prints its own per-image log line.
    # Single-file examples used previously:
    #   "../data/test/char_9.jpg"
    #   "../data/equation/38376_减_1_问_加_4_除.jpg"
    for image_file in glob("./*.jpg"):
        recognize(cv2.imread(image_file))
|