123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- import os
- import re
- import sys
- from glob import glob
- import cv2
- import numpy as np
- import tensorflow as tf
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from model import crnn_ctc_equation, ctc_decode, crnn_ctc_equation_large
- from utils import pil_resize, add_contrast
# Absolute directory of this module; the model weights and charset file are
# resolved relative to it.
package_dir = os.path.abspath(os.path.dirname(__file__))
# Input shape handed to the network builder; recognize() resizes images with
# image_shape[0] and image_shape[1] — presumably (height, width, channels).
image_shape = (32, 192, 1)
model_path = package_dir + "/models/e55-loss0.14-equation.h5"

# Charset: every line of equation.txt, kept raw in char_list and flattened
# into a single newline-free string in char_str (indexed by recognize()).
with open(package_dir + "/equation.txt", 'r', encoding='utf-8') as f:
    char_list = f.readlines()
char_str = "".join(char_list).replace("\n", "")
def recognize(image_np, model=None, sess=None):
    """Recognize the equation in an image and return its evaluated result.

    Args:
        image_np: decoded image array (e.g. from cv2.imread / cv2.imdecode).
            3-channel input is converted with COLOR_BGR2GRAY and 4-channel
            input with COLOR_BGRA2GRAY; 2-D or single-channel input is used
            as grayscale directly.
        model: optional already-built recognition model. When None, a model is
            built with ``crnn_ctc_equation_large`` and its weights are loaded
            from ``model_path`` on every call, which is expensive; pass a model
            (together with the session it was built in) when recognizing many
            images.
        sess: optional ``tf.compat.v1.Session`` whose graph holds ``model``.
            When None, a new session on a fresh ``tf.Graph`` is created.

    Returns:
        Whatever ``calculate`` returns for the decoded characters: an int
        result, or None when they do not form a valid two-operand equation.
    """
    if sess is None:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    if model is None:
        with sess.as_default():
            with sess.graph.as_default():
                # 35 + 2 output classes. NOTE(review): presumably the charset
                # size plus CTC-specific classes — confirm against equation.txt.
                model = crnn_ctc_equation_large(input_shape=image_shape, class_num=35+2, is_train=False)
                model.load_weights(model_path)

    img = pil_resize(image_np, image_shape[0], image_shape[1])
    # Reduce to one channel with an explicit trailing channel axis. 2-D and
    # 4-channel (BGRA) inputs are handled explicitly so they neither fail on
    # shape[2] nor reach the model with 4 channels.
    if len(img.shape) == 2:
        img = np.expand_dims(img, axis=-1)
    elif img.shape[2] == 3:
        img = np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), axis=-1)
    elif img.shape[2] == 4:
        img = np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY), axis=-1)
    img = add_contrast(img)

    # Batch of one image, pixel values scaled by 1/255.
    X = np.array([img / 255.])
    with sess.as_default():
        with sess.graph.as_default():
            pred = ctc_decode(X, model)

    # Decoded class k maps to char_str[k - 1]; class 0 (k - 1 < 0) is skipped.
    # NOTE(review): presumably class 0 is the CTC blank — confirm in ctc_decode.
    result_list = []
    for index in pred:
        index = int(index - 1)
        if index < 0:
            continue
        result_list.append(char_str[index])
    print("cer result", result_list)
    return calculate(result_list)
# Normalization table for recognized characters: Arabic and Chinese numerals
# map to their digit value, operator symbols map to the Chinese operator word
# (加/减/乘/除), and filler characters ("=", "?", 上, 去, 以) map to "" so
# they disappear from the equation string.
_CHAR_DICT = {
    "1": 1,
    "2": 2,
    "3": 3,
    "4": 4,
    "5": 5,
    "6": 6,
    "7": 7,
    "8": 8,
    "9": 9,
    "0": 0,
    "一": 1,
    "二": 2,
    "三": 3,
    "四": 4,
    "五": 5,
    "六": 6,
    "七": 7,
    "八": 8,
    "九": 9,
    "零": 0,
    "加": "加",
    "减": "减",
    "乘": "乘",
    "除": "除",
    "+": "加",
    "-": "减",
    "*": "乘",
    "×": "乘",
    "/": "除",
    "÷": "除",
    "=": "",
    "?": "",
    "上": "",
    "去": "",
    "以": "",
}

# Matches any of the four normalized operator words.
_OP_PATTERN = re.compile("加|减|乘|除")


def calculate(_list):
    """Evaluate a recognized two-operand arithmetic equation.

    Args:
        _list: iterable of single-character strings, e.g. ["3", "+", "4", "="].

    Returns:
        The int result of ``<number> <operator> <number>``, or None when the
        characters do not form exactly that: an unrecognized character, no
        operator, two different operators, a repeated operator, or a missing
        operand. Division truncates toward zero, and a zero divisor is
        replaced by 1.
    """
    tokens = []
    for c in _list:
        token = _CHAR_DICT.get(c)
        if token is None:
            # Unrecognized character: the equation cannot be parsed.
            return None
        tokens.append(str(token))
    equation_str = "".join(tokens)

    ops = set(_OP_PATTERN.findall(equation_str))
    if len(ops) != 1:
        return None
    nums = _OP_PATTERN.split(equation_str)
    if len(nums) != 2:
        return None
    try:
        num1 = int(nums[0])
        num2 = int(nums[1])
    except ValueError:
        # An operand is empty (e.g. trailing operator) or otherwise not numeric.
        print("非数字!")
        return None

    op = ops.pop()
    if op == "加":
        return num1 + num2
    if op == "减":
        return num1 - num2
    if op == "乘":
        return num1 * num2
    # op == "除". Operands are built only from digits 0-9, so both are
    # non-negative and floor division equals truncation; integer division
    # also stays exact for operands too large for a float.
    return num1 // max(num2, 1)
if __name__ == "__main__":
    # Other example inputs:
    # _path = "../data/test/char_9.jpg"
    # _path = "../data/equation/38376_减_1_问_加_4_除.jpg"
    _paths = glob("../data/test/FileInfo1021/*")

    # Build the model once and reuse it for every image, instead of letting
    # recognize() rebuild it and reload the weights per call (same
    # construction recognize() uses when model is None).
    sess = tf.compat.v1.Session(graph=tf.Graph())
    with sess.as_default():
        with sess.graph.as_default():
            model = crnn_ctc_equation_large(input_shape=image_shape, class_num=35+2, is_train=False)
            model.load_weights(model_path)

    for _path in _paths:
        if not os.path.isfile(_path):
            continue
        # recognize() expects a decoded image array, not a file path.
        # np.fromfile + cv2.imdecode is used instead of cv2.imread so that
        # paths containing non-ASCII characters (like the Chinese file names
        # above) can be read as well.
        img = cv2.imdecode(np.fromfile(_path, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            print("cannot decode image:", _path)
            continue
        print(_path, recognize(img, model=model, sess=sess))
|