inference_char.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
import os
import sys

import tensorflow as tf
import cv2
import numpy as np

# Make the parent directory importable so the project-local modules below resolve.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from chinese_recognize.model import cnn_net  # noqa: E402
from utils import pil_resize  # noqa: E402
  9. package_dir = os.path.abspath(os.path.dirname(__file__))
  10. image_shape = (40, 40, 3)
  11. model_path = package_dir + "/models/char_f1_0.93.h5"
  12. with open(package_dir + "/chinese_5710.txt") as f:
  13. char_str = f.read()
  14. def recognize(image_np_list, model=None, sess=None):
  15. if sess is None:
  16. sess = tf.compat.v1.Session(graph=tf.Graph())
  17. if model is None:
  18. with sess.as_default():
  19. with sess.graph.as_default():
  20. model = cnn_net(input_shape=image_shape)
  21. model.load_weights(model_path)
  22. if len(image_np_list) > 30:
  23. raise
  24. # 准备批次数据
  25. X = np.zeros((len(image_np_list), image_shape[0], image_shape[1], image_shape[2]))
  26. for i in range(len(image_np_list)):
  27. img = pil_resize(image_np_list[i], image_shape[0], image_shape[1])
  28. img = img / 255.
  29. X[i] = img
  30. with sess.as_default():
  31. with sess.graph.as_default():
  32. pred = model.predict(X)
  33. index_list = np.argmax(pred, axis=1).tolist()
  34. char_list = []
  35. for index in index_list:
  36. char_list.append(char_str[index])
  37. return char_list
  38. if __name__ == "__main__":
  39. _path = "D:/Project/captcha/data/test/char_2.jpg"
  40. print(recognize([cv2.imread(_path)]))