inference_char.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import re
  3. import sys
  4. import tensorflow as tf
  5. import cv2
  6. import numpy as np
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  8. from chinese_recognize.model import cnn_net, cnn_net_tiny, cnn_net_small
  9. from utils import pil_resize
  10. package_dir = os.path.abspath(os.path.dirname(__file__))
  11. image_shape = (40, 40, 1)
  12. model_path = package_dir + "/models/char_acc_0.89.h5"
  13. with open(package_dir + "/chinese_6270.txt") as f:
  14. char_list = f.readlines()
  15. char_str = "".join(char_list)
  16. char_str = re.sub("\n", "", char_str)
  17. def recognize(image_np_list, model=None, sess=None):
  18. if sess is None:
  19. sess = tf.compat.v1.Session(graph=tf.Graph())
  20. if model is None:
  21. with sess.as_default():
  22. with sess.graph.as_default():
  23. model = cnn_net_small(input_shape=image_shape)
  24. model.load_weights(model_path)
  25. if len(image_np_list) > 30:
  26. raise
  27. # 准备批次数据
  28. X = np.zeros((len(image_np_list), image_shape[0], image_shape[1], image_shape[2]))
  29. for i in range(len(image_np_list)):
  30. img = pil_resize(image_np_list[i], image_shape[0], image_shape[1])
  31. # cv2.imshow("char img", img)
  32. # cv2.waitKey(0)
  33. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  34. img = np.expand_dims(img, axis=-1)
  35. img = img / 255.
  36. X[i] = img
  37. with sess.as_default():
  38. with sess.graph.as_default():
  39. pred = model.predict(X)
  40. index_list = np.argmax(pred, axis=1).tolist()
  41. char_list = []
  42. for index in index_list:
  43. char_list.append(char_str[index])
  44. return char_list
  45. def recognize_no_sess(image_np_list, model=None):
  46. if model is None:
  47. model = cnn_net(input_shape=image_shape)
  48. model.load_weights(model_path)
  49. if len(image_np_list) > 30:
  50. raise
  51. # 准备批次数据
  52. X = np.zeros((len(image_np_list), image_shape[0], image_shape[1], image_shape[2]))
  53. for i in range(len(image_np_list)):
  54. img = pil_resize(image_np_list[i], image_shape[0], image_shape[1])
  55. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  56. img = np.expand_dims(img, axis=-1)
  57. img = img / 255.
  58. X[i] = img
  59. pred = model.predict(X)
  60. index_list = np.argmax(pred, axis=1).tolist()
  61. char_list = []
  62. for index in index_list:
  63. char_list.append(char_str[index])
  64. return char_list
  65. if __name__ == "__main__":
  66. _path = "D:/Project/captcha/data/test/char_4.jpg"
  67. print(recognize([cv2.imread(_path)]))