1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- import os
- import re
- import sys
- import tensorflow as tf
- import cv2
- import numpy as np
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from captcha_classify.model import cnn_net_tiny
- from utils import pil_resize
# Model input geometry as (height, width, channels); a single channel because
# classify() converts every image to grayscale before prediction.
image_shape = (128, 128, 1)
# Size of the classifier's output layer (number of captcha categories).
class_num = 3

# Absolute directory of this module. The weight file is resolved relative to
# it so loading does not depend on the caller's working directory.
package_dir = os.path.abspath(os.path.dirname(__file__))
# Pre-trained cnn_net_tiny weights (the file name suggests epoch 262,
# accuracy 0.81 -- informational only).
model_path = package_dir + "/models/e262-acc0.81-classify.h5"
def _preprocess(image_np):
    """Turn one BGR image into a model-ready batch containing that image.

    The image is resized with ``pil_resize`` to ``image_shape[0]`` x
    ``image_shape[1]``, converted to grayscale, and given a trailing channel
    axis. It is then scaled by 1/255 and wrapped in a leading batch axis of
    size 1.

    :param image_np: a single colour image; BGR channel order is assumed
        because of the ``COLOR_BGR2GRAY`` conversion.
    :return: ``np.ndarray`` batch of one preprocessed image.
    """
    img = pil_resize(image_np, image_shape[0], image_shape[1])
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Restore the channel axis dropped by the grayscale conversion, then scale.
    img = np.expand_dims(img, axis=-1) / 255.
    # Leading batch axis; equivalent to np.array([img]).
    return np.expand_dims(img, axis=0)


def classify(image_np, model=None, sess=None):
    """Predict the captcha category of a single image.

    :param image_np: ONE BGR image (e.g. the result of ``cv2.imread``), not a
        list of images.
    :param model: an already-built ``cnn_net_tiny`` with weights loaded. When
        omitted, a fresh model is built and ``model_path`` is loaded on every
        call. That is slow, so pass a cached model (and its session) for
        repeated use.
    :param sess: the ``tf.compat.v1.Session`` whose graph ``model`` was built
        in. If only ``sess`` is given, the model is built inside it. If only
        ``model`` is given, prediction runs without entering any session.
    :return: index of the highest-scoring output of the model as an ``int``.
    :raises ValueError: if ``image_np`` is None (e.g. ``cv2.imread`` failed).
    """
    if image_np is None:
        raise ValueError("image_np is None; was the image read successfully?")

    # Only create (and therefore close) a session when we also build the model.
    # The previous code also created an empty session when a model was passed
    # without one. Predicting inside that unrelated graph cannot work.
    owns_sess = model is None and sess is None
    if owns_sess:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        if model is None:
            with sess.as_default(), sess.graph.as_default():
                model = cnn_net_tiny(input_shape=image_shape, output_shape=class_num)
                model.load_weights(model_path)

        X = _preprocess(image_np)

        if sess is None:
            pred = model.predict(X)
        else:
            with sess.as_default(), sess.graph.as_default():
                pred = model.predict(X)
    finally:
        # Release the graph/session we created; a caller-supplied sess stays open.
        if owns_sess:
            sess.close()

    # The batch holds a single image, so the flat argmax is that image's class.
    cls = int(np.argmax(pred))
    print("cac result", cls)
    return cls
if __name__ == "__main__":
    # Optional image path on the command line; falls back to the original
    # hard-coded sample.
    _path = sys.argv[1] if len(sys.argv) > 1 else "D:/Project/captcha/data/test/char_4.jpg"
    _image = cv2.imread(_path)
    # cv2.imread signals failure by returning None rather than raising.
    if _image is None:
        sys.exit("could not read image: " + _path)
    # classify() takes a single image, not a list of images.
    print(classify(_image))
|