import os
import re
import sys
import tensorflow as tf
import cv2
import numpy as np

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from captcha_classify.model import cnn_net_tiny
from utils import pil_resize

package_dir = os.path.abspath(os.path.dirname(__file__))
image_shape = (128, 128, 1)
class_num = 3
model_path = package_dir + "/models/e262-acc0.81-classify.h5"


def _preprocess(image_np):
    """Turn one raw image into a single-sample batch for the classifier.

    Resizes with ``pil_resize`` to ``image_shape[:2]``, collapses colour input
    to one gray channel, divides by 255 and adds batch and channel axes.
    NOTE(review): assumes ``pil_resize`` returns an (H, W) or (H, W, C) array
    and that pixel values are 8-bit (so the result lies in [0, 1]) — confirm.
    """
    img = np.asarray(pil_resize(image_np, image_shape[0], image_shape[1]))
    if img.ndim == 3:
        # Multi-channel input is treated as OpenCV BGR(A) and converted to gray;
        # a trailing singleton channel is already gray, so just drop that axis
        # and let expand_dims below restore exactly one channel dimension.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.shape[-1] > 1 else img[..., 0]
    img = np.expand_dims(img, axis=-1) / 255.
    # Leading batch axis of size 1, as model.predict expects a batch.
    return np.expand_dims(img, axis=0)


def classify(image_np, model=None, sess=None):
    """Classify a captcha image into one of ``class_num`` categories.

    Args:
        image_np: a single image array. Multi-channel input is converted with
            ``cv2.COLOR_BGR2GRAY``, so it is presumably BGR as returned by
            ``cv2.imread`` — confirm with callers. Pass one image, not a list.
        model: optional pre-built ``cnn_net_tiny`` model. When omitted, a model
            is built and its weights loaded from ``model_path`` on this call.
        sess: optional ``tf.compat.v1.Session`` whose graph ``model`` lives in.
            When omitted, a fresh session on a new ``tf.Graph`` is created for
            this call and closed before returning; a caller-supplied session
            is left open.

    Returns:
        int: index of the highest-scoring class (``np.argmax`` of the prediction).

    Raises:
        ValueError: if ``image_np`` is None (e.g. ``cv2.imread`` could not read
            the file).
    """
    if image_np is None:
        raise ValueError("classify: image_np is None (was the image read successfully?)")

    # Only sessions created here are closed here; otherwise each call would
    # leak a session and its graph.
    owns_session = sess is None
    if owns_session:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        if model is None:
            # Build the model inside the session's graph so predict() below
            # runs against the same graph.
            with sess.as_default(), sess.graph.as_default():
                model = cnn_net_tiny(input_shape=image_shape, output_shape=class_num)
                model.load_weights(model_path)

        X = _preprocess(image_np)

        with sess.as_default(), sess.graph.as_default():
            pred = model.predict(X)
    finally:
        if owns_session:
            sess.close()

    cls = int(np.argmax(pred))
    print("cac result", cls)
    return cls


if __name__ == "__main__":
    _path = "D:/Project/captcha/data/test/char_4.jpg"
    # classify() takes a single image, not a list of images.
    print(classify(cv2.imread(_path)))