inference_phrase.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. import random
  3. from itertools import product
  4. import cv2
  5. import numpy as np
  6. from click_captcha.model import mobile_net, cnn_net, u_net_drag, lstm_phrase, text_cnn_phrase
  7. from click_captcha.utils import pil_resize
  8. vocabulary_len = 5792
  9. sequence_len = 6
  10. weights_path = "./models/e08-f10.85-phrase.h5"
  11. project_dir = os.path.dirname(os.path.abspath(__file__)) + "/../"
  12. def recognize(char_list):
  13. model = lstm_phrase((sequence_len, vocabulary_len))
  14. model.load_weights(weights_path)
  15. if len(char_list) > sequence_len:
  16. return None
  17. char_path = "../data/phrase/char.txt"
  18. with open(char_path, "r") as f:
  19. char_map_list = f.readlines()
  20. char_dict = {}
  21. char_dict_reverse = {}
  22. for i in range(len(char_map_list)):
  23. char_dict[char_map_list[i][:-1]] = i
  24. char_dict_reverse[i] = char_map_list[i][:-1]
  25. char_dict_reverse[vocabulary_len] = ""
  26. index_list = [char_dict[x] for x in char_list]
  27. products = list(product(index_list, repeat=len(index_list)))
  28. all_index_list = []
  29. for index_list in products:
  30. index_list = list(index_list)
  31. if len(set(index_list)) != len(index_list):
  32. continue
  33. index_list = index_list + [vocabulary_len]*(sequence_len-len(index_list))
  34. all_index_list.append(index_list)
  35. X = np.array(all_index_list)
  36. pred = model.predict(X)
  37. for i in range(len(pred)):
  38. p = all_index_list[i]
  39. print(pred[i], "".join([char_dict_reverse[x] for x in p]))
  40. decode = all_index_list[int(np.argmax(pred))]
  41. decode = "".join([char_dict_reverse[x] for x in decode])
  42. print("".join(char_list))
  43. print(decode)
  44. if __name__ == "__main__":
  45. text = "吃饭了"
  46. text = [x for x in text]
  47. random.shuffle(text)
  48. print("text", text)
  49. recognize(text)