# inference_classify.py (1.3 KB)
  1. import os
  2. import re
  3. import sys
  4. import tensorflow as tf
  5. import cv2
  6. import numpy as np
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  8. from captcha_classify.model import cnn_net_tiny
  9. from utils import pil_resize
  10. package_dir = os.path.abspath(os.path.dirname(__file__))
  11. image_shape = (128, 128, 1)
  12. class_num = 3
  13. model_path = package_dir + "/models/e262-acc0.81-classify.h5"
  14. def classify(image_np, model=None, sess=None):
  15. if sess is None:
  16. sess = tf.compat.v1.Session(graph=tf.Graph())
  17. if model is None:
  18. with sess.as_default():
  19. with sess.graph.as_default():
  20. model = cnn_net_tiny(input_shape=image_shape, output_shape=class_num)
  21. model.load_weights(model_path)
  22. img = image_np
  23. X = []
  24. img = pil_resize(img, image_shape[0], image_shape[1])
  25. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  26. img = np.expand_dims(img, axis=-1)
  27. img = img / 255.
  28. X.append(img)
  29. X = np.array(X)
  30. with sess.as_default():
  31. with sess.graph.as_default():
  32. pred = model.predict(X)
  33. cls = int(np.argmax(pred))
  34. print("cac result", cls)
  35. return cls
  36. if __name__ == "__main__":
  37. _path = "D:/Project/captcha/data/test/char_4.jpg"
  38. print(classify([cv2.imread(_path)]))