inference_char.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import os
  2. import re
  3. from glob import glob
  4. import cv2
  5. import numpy as np
  6. from click_captcha.model import mobile_net, cnn_net, cnn_net_tiny, cnn_net_small
  7. from click_captcha.utils import pil_resize
  8. image_shape = (40, 40, 1)
  9. weights_path = "./models/e01-acc0.83-char.h5"
  10. project_dir = os.path.dirname(os.path.abspath(__file__)) + "/../"
  11. def recognize(image_path):
  12. model = cnn_net_small(input_shape=image_shape)
  13. model.load_weights(weights_path)
  14. paths = glob("../data/test/char_*.jpg")
  15. X = []
  16. for image_path in paths:
  17. print(image_path)
  18. img = cv2.imread(image_path)
  19. img = pil_resize(img, image_shape[0], image_shape[1])
  20. # cv2.imshow("img", img)
  21. # cv2.waitKey(0)
  22. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  23. img = np.expand_dims(img, axis=-1)
  24. img = img / 255.
  25. # X = np.expand_dims(img, 0)
  26. X.append(img)
  27. X = np.array(X)
  28. print("X.shape", X.shape)
  29. pred = model.predict(X)
  30. print("pred.shape", pred.shape)
  31. with open(project_dir + "data/chinese_6270.txt", 'r') as f:
  32. char_list = f.readlines()
  33. char_str = "".join(char_list)
  34. char_str = re.sub("\n", "", char_str)
  35. for p in pred:
  36. index = int(np.argmax(p))
  37. print(char_str[index], p[index])
  38. # index_list = []
  39. # prob_list = []
  40. # for i in range(5):
  41. # index = int(np.argmax(pred))
  42. # index_list.append(index)
  43. # prob_list.append(np.max(pred))
  44. # pred = np.delete(pred, index)
  45. #
  46. # index = index_list[0]
  47. # print("index_list", index_list)
  48. # print("index", index)
  49. # with open(project_dir + "data/chinese_6270.txt", 'r') as f:
  50. # char_list = f.readlines()
  51. # char_str = "".join(char_list)
  52. # char_str = re.sub("\n", "", char_str)
  53. # char = char_str[index]
  54. # print("recognize chinese", char, prob_list[0])
  55. #
  56. # for i in range(1, len(index_list)):
  57. # print("possible chinese", i, char_str[index_list[i]], prob_list[i])
  58. return
  59. if __name__ == "__main__":
  60. # _path = "../data/test/char_9.jpg"
  61. _path = "../data/click/80_70_2.jpg"
  62. # _path = "../data/click/2019_73_1.jpg"
  63. recognize(_path)