import os
import re
import sys
import tensorflow as tf
import cv2
import numpy as np

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from chinese_recognize.model import cnn_net, cnn_net_tiny, cnn_net_small
from utils import pil_resize

package_dir = os.path.abspath(os.path.dirname(__file__))
# Model input shape (height, width, channels); _prepare_batch produces
# single-channel images of exactly this shape.
image_shape = (40, 40, 1)
model_path = package_dir + "/models/char_acc_0.89.h5"
# Largest number of character images accepted in a single recognize call.
MAX_BATCH_SIZE = 30

# Character vocabulary: a predicted class index i is decoded as char_str[i].
# NOTE(review): open() uses the platform's default encoding here; if the
# vocabulary file is stored in a fixed encoding (e.g. UTF-8), pass
# encoding=... explicitly so decoding does not depend on the machine's locale.
with open(package_dir + "/chinese_6270.txt") as f:
    char_list = f.readlines()
char_str = "".join(char_list)
char_str = re.sub("\n", "", char_str)


def _prepare_batch(image_np_list):
    """Convert character images into one normalised batch for the model.

    Each image is resized with ``pil_resize`` to ``image_shape[:2]``,
    converted from BGR to grayscale, given a trailing channel axis and
    scaled by 1/255.

    Args:
        image_np_list: sequence of images. After ``pil_resize`` each must be
            a 3-channel BGR array, because it is converted with
            ``cv2.COLOR_BGR2GRAY`` (``cv2.imread`` output is such an image).

    Returns:
        float64 array of shape ``(len(image_np_list),) + image_shape``.

    Raises:
        ValueError: if more than ``MAX_BATCH_SIZE`` images are given.
    """
    if len(image_np_list) > MAX_BATCH_SIZE:
        # Explicit error: a bare ``raise`` outside an ``except`` block fails
        # with an unrelated "No active exception to reraise" RuntimeError.
        raise ValueError(
            "at most %d images can be recognized at once, got %d"
            % (MAX_BATCH_SIZE, len(image_np_list))
        )
    batch = np.zeros((len(image_np_list),) + image_shape)
    for i, image_np in enumerate(image_np_list):
        img = pil_resize(image_np, image_shape[0], image_shape[1])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        batch[i] = np.expand_dims(img, axis=-1) / 255.
    return batch


def _decode_predictions(pred):
    """Map each row of model scores to the character at its argmax index in ``char_str``."""
    return [char_str[index] for index in np.argmax(pred, axis=1).tolist()]


def recognize(image_np_list, model=None, sess=None):
    """Recognize a batch of single-character images inside a TF1-style session.

    Args:
        image_np_list: character images; after ``pil_resize`` each must be a
            3-channel BGR array (e.g. as returned by ``cv2.imread``).
        model: an already-built model. When None, ``cnn_net_small`` is built
            in ``sess``'s graph and the weights at ``model_path`` are loaded.
            NOTE(review): passing ``model`` without ``sess`` runs it under a
            fresh, empty graph; presumably callers should pass the session the
            model was built in.
        sess: ``tf.compat.v1.Session`` to run in. When None, a session on a
            new graph is created for this call and closed before returning.

    Returns:
        List of predicted characters, one per input image, in input order.
        An empty input returns an empty list without touching the model.

    Raises:
        ValueError: if more than ``MAX_BATCH_SIZE`` images are given.
    """
    # Preprocess (and validate the batch size) before any expensive model work.
    batch = _prepare_batch(image_np_list)
    if batch.shape[0] == 0:
        return []

    owns_session = sess is None
    if owns_session:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        with sess.as_default(), sess.graph.as_default():
            if model is None:
                model = cnn_net_small(input_shape=image_shape)
                model.load_weights(model_path)
            pred = model.predict(batch)
    finally:
        # Only close a session this call created; a caller's session stays open.
        if owns_session:
            sess.close()
    return _decode_predictions(pred)


def recognize_no_sess(image_np_list, model=None):
    """Recognize a batch of single-character images without an explicit session.

    Args:
        image_np_list: character images; after ``pil_resize`` each must be a
            3-channel BGR array (e.g. as returned by ``cv2.imread``).
        model: an already-built model. When None, ``cnn_net`` is built and
            the weights at ``model_path`` are loaded into it.

    Returns:
        List of predicted characters, one per input image, in input order.
        An empty input returns an empty list without touching the model.

    Raises:
        ValueError: if more than ``MAX_BATCH_SIZE`` images are given.
    """
    batch = _prepare_batch(image_np_list)
    if batch.shape[0] == 0:
        return []

    if model is None:
        # NOTE(review): this builds cnn_net while recognize() builds
        # cnn_net_small, yet both load the same weights file; confirm which
        # architecture actually matches model_path.
        model = cnn_net(input_shape=image_shape)
        model.load_weights(model_path)
    return _decode_predictions(model.predict(batch))


if __name__ == "__main__":
    # Optional first argument overrides the default sample image path.
    _path = sys.argv[1] if len(sys.argv) > 1 else "D:/Project/captcha/data/test/char_4.jpg"
    _image = cv2.imread(_path)
    # cv2.imread returns None instead of raising when the file cannot be read.
    if _image is None:
        sys.exit("cannot read image: %s" % _path)
    print(recognize([_image]))