"""Recognize arithmetic-equation captchas with a CRNN-CTC model and evaluate accuracy.

Running this module as a script decodes every image matched by
``../data/test/FileInfo1021/*`` and reports how many evaluated results equal the
label encoded in each file name (the text after the last ``_``, extension removed).
"""
import os
import re
from functools import lru_cache
from glob import glob

import cv2
import numpy as np

from click_captcha.inference_equation_denoise import denoise
from click_captcha.model import crnn_ctc_equation, ctc_decode, crnn_ctc_equation_large, crnn_ctc_equation_less, \
    crnn_ctc_equation_loss
from click_captcha.pre_process import eight_neighbour, connected_component, add_contrast
from click_captcha.utils import pil_resize

image_shape = (32, 192, 1)
# Resolved relative to the current working directory (unlike data/equation.txt,
# which is resolved relative to project_dir).
weights_path = "./models/e55-loss0.14-equation.h5"
project_dir = os.path.dirname(os.path.abspath(__file__)) + "/../"

# class_num = 35 + 2: presumably the vocabulary size plus reserved CTC classes —
# TODO confirm against the training setup.
model = crnn_ctc_equation_loss(input_shape=image_shape, class_num=35+2, is_train=False)
model.load_weights(weights_path)

# Maps each recognizable character to a digit string, an operator word
# (加/减/乘/除), or "" for filler characters that are ignored.
_CHAR_DICT = {
    **{c: c for c in "0123456789"},
    **dict(zip("零一二三四五六七八九", "0123456789")),
    "加": "加", "减": "减", "乘": "乘", "除": "除",
    "+": "加", "-": "减", "*": "乘", "×": "乘", "/": "除", "÷": "除",
    **dict.fromkeys("=?上去以", ""),
}
_OPERATOR_PATTERN = re.compile("加|减|乘|除")


@lru_cache(maxsize=None)
def _load_char_str():
    """Return the vocabulary from data/equation.txt with all newlines removed.

    Cached so the file is read once per process rather than once per image.
    """
    with open(project_dir + "data/equation.txt", "r", encoding="utf-8") as f:
        return re.sub("\n", "", f.read())


def recognize(image_path):
    """Decode one captcha image and compare its evaluated equation with the file-name label.

    The label is the text after the last "_" in the file name, with everything
    from the first "." removed (e.g. ``abc_12.jpg`` -> ``12``).

    Args:
        image_path: path of the captcha image, read with ``cv2.imread``.

    Returns:
        1 if the evaluated equation equals the label, otherwise 0. This includes
        the cases where the image cannot be read, the equation cannot be evaluated,
        or the label is not an integer.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None, not raising.
        print("cannot read image:", image_path)
        return 0
    img = pil_resize(img, image_shape[0], image_shape[1])
    img = denoise(img)
    img = add_contrast(img)
    if img.ndim == 3 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if img.ndim == 2:
        # Add back the trailing channel axis for the single-channel model input.
        img = np.expand_dims(img, axis=-1)
    img = img / 255.

    X = np.array([img])  # batch of one
    pred = ctc_decode(X, model)
    print(image_path)
    pred = pred[0][0]

    char_str = _load_char_str()
    result_list = []
    for index in pred:
        # Decoded ids are shifted down by one before indexing the vocabulary,
        # presumably because id 0 is reserved — confirm against training labels.
        index = int(index - 1)
        # Ids <= 0 and ids past the vocabulary map to no character.
        if index < 0 or index >= len(char_str):
            continue
        result_list.append(char_str[index])
    print(result_list)

    # Parse from the basename so a "_" in a directory name cannot leak into the label.
    label = os.path.basename(image_path).split("_")[-1].split(".")[0]
    pred = calculate(result_list)
    print("计算结果:", pred, label)
    if pred is None:
        return 0
    try:
        expected = int(label)
    except ValueError:
        print("label is not an integer:", label)
        return 0
    return 1 if expected == pred else 0


def calculate1(_list):
    """Evaluate a recognized equation; identical to :func:`calculate`.

    Kept for existing callers — the two functions previously had duplicate bodies.
    """
    return calculate(_list)


def calculate(_list):
    """Evaluate a recognized binary equation such as ["1", "2", "加", "三", "="].

    Consecutive digits (Arabic or Chinese) are concatenated into an operand.
    Exactly one operator (加/减/乘/除 or +, -, *, ×, /, ÷) must appear, exactly
    once. The characters "=", "?", "上", "去", "以" are ignored.

    Args:
        _list: sequence of recognized single-character strings.

    Returns:
        The integer result, or None when the sequence contains an unmapped
        character, does not contain exactly one operator, or an operand is empty.
        Division is floor division, with a zero divisor treated as 1.
    """
    parts = []
    for c in _list:
        mapped = _CHAR_DICT.get(c)
        if mapped is None:
            # Unmapped characters can never form a valid equation.
            return None
        parts.append(mapped)
    equation_str = "".join(parts)

    ops = set(_OPERATOR_PATTERN.findall(equation_str))
    if len(ops) != 1:
        return None
    nums = _OPERATOR_PATTERN.split(equation_str)
    if len(nums) != 2:
        # The same operator appeared more than once.
        return None
    try:
        num1 = int(nums[0])
        num2 = int(nums[1])
    except ValueError:
        print("非数字!")
        return None

    op = ops.pop()
    if op == "加":
        return num1 + num2
    if op == "减":
        return num1 - num2
    if op == "乘":
        return num1 * num2
    # op == "除". Operands are built only from digits, so they are non-negative and
    # floor division matches truncating division without float rounding.
    return num1 // max(num2, 1)


if __name__ == "__main__":
    _paths = glob("../data/test/FileInfo1021/*")
    right_num = sum(recognize(_path) for _path in _paths)
    if _paths:
        print("准确率:", right_num / len(_paths), right_num, len(_paths))
    else:
        print("no test images found")