1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- import os
- import sys
- import tensorflow as tf
- import cv2
- import numpy as np
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from chinese_recognize.model import cnn_net
- from utils import pil_resize
# Absolute directory of this module; the model weights and the character
# table are both resolved relative to it.
package_dir = os.path.abspath(os.path.dirname(__file__))

# Size every input image is resized to before inference (40, 40, 3);
# presumably (rows, cols, channels) — TODO confirm against cnn_net.
image_shape = (40, 40, 3)

# Trained weights, loaded via cnn_net(...).load_weights(model_path).
model_path = "{}/models/char_f1_0.93.h5".format(package_dir)

# Character table: the class index predicted by the CNN is used directly as
# a position into this string.
# NOTE(review): opened with the platform default encoding; because the table
# holds Chinese characters, consider pinning `encoding=` once the file's
# actual on-disk encoding is confirmed.
with open("{}/chinese_5710.txt".format(package_dir)) as f:
    char_str = f.read()
# Upper bound on how many character images one call may recognize.
_MAX_BATCH_SIZE = 30


def _build_batch(image_np_list):
    """Resize and scale the input images into one float batch for the CNN.

    Each image is resized with ``pil_resize`` to
    ``image_shape[0] x image_shape[1]`` and divided by 255 (presumably mapping
    8-bit pixels into [0, 1] — TODO confirm input dtype).

    Raises:
        RuntimeError: if more than ``_MAX_BATCH_SIZE`` images are given.
    """
    count = len(image_np_list)
    if count > _MAX_BATCH_SIZE:
        # The original used a bare ``raise`` here, which (outside an except
        # block) surfaces as RuntimeError; keep that type but add a message.
        raise RuntimeError(
            "recognize accepts at most {} images per call, got {}".format(
                _MAX_BATCH_SIZE, count))
    batch = np.zeros((count,) + tuple(image_shape))
    for i, image in enumerate(image_np_list):
        batch[i] = pil_resize(image, image_shape[0], image_shape[1]) / 255.
    return batch


def _decode(pred):
    """Map each row of model output to the character at its argmax index."""
    return [char_str[index] for index in np.argmax(pred, axis=1).tolist()]


def recognize(image_np_list, model=None, sess=None):
    """Recognize one character per image, running the CNN in a TF1 session.

    Args:
        image_np_list: sequence of character images, each accepted by
            ``pil_resize`` (presumably arrays as returned by ``cv2.imread`` —
            TODO confirm).
        model: model built by ``cnn_net``. When None, a new model is built
            inside ``sess``'s graph and weights are loaded from ``model_path``.
        sess: ``tf.compat.v1.Session`` to build/run the model in. When None, a
            session over a fresh ``tf.Graph`` is created for this call and
            closed before returning.

    Returns:
        list of characters taken from ``char_str``, one per input image
        (empty list for empty input).

    Raises:
        RuntimeError: if more than ``_MAX_BATCH_SIZE`` images are passed.
    """
    if len(image_np_list) == 0:
        return []
    # Validate and preprocess before paying for session/model creation.
    batch = _build_batch(image_np_list)

    # NOTE(review): if a model is passed without the session it was built in,
    # prediction runs under a fresh, unrelated graph; in graph mode this
    # presumably fails — callers should pass the matching session.
    owns_sess = sess is None
    if owns_sess:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        with sess.as_default(), sess.graph.as_default():
            if model is None:
                model = cnn_net(input_shape=image_shape)
                model.load_weights(model_path)
            pred = model.predict(batch)
    finally:
        # Only close a session this call created; a caller's session is theirs.
        if owns_sess:
            sess.close()
    return _decode(pred)


def recognize_no_sess(image_np_list, model=None):
    """Recognize one character per image using the default (e.g. eager) context.

    Args:
        image_np_list: sequence of character images, each accepted by
            ``pil_resize``.
        model: model built by ``cnn_net``. When None, a new model is built and
            weights are loaded from ``model_path`` on every call.

    Returns:
        list of characters taken from ``char_str``, one per input image
        (empty list for empty input).

    Raises:
        RuntimeError: if more than ``_MAX_BATCH_SIZE`` images are passed.
    """
    if len(image_np_list) == 0:
        return []
    batch = _build_batch(image_np_list)
    if model is None:
        model = cnn_net(input_shape=image_shape)
        model.load_weights(model_path)
    return _decode(model.predict(batch))
if __name__ == "__main__":
    # Usage: python <this file> [image_path]
    # Without an argument, falls back to the original hard-coded sample path.
    _path = (sys.argv[1] if len(sys.argv) > 1
             else "D:/Project/captcha/data/test/char_6.jpg")
    _img = cv2.imread(_path)
    if _img is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising, which would otherwise crash inside pil_resize.
        sys.exit("could not read image: {}".format(_path))
    print(recognize([_img]))
|