inference_phrase.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. import random
  3. import sys
  4. from itertools import product
  5. import tensorflow as tf
  6. import numpy as np
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  8. from chinese_order.model import lstm_phrase, text_cnn_phrase
  9. package_dir = os.path.abspath(os.path.dirname(__file__))
  10. vocabulary_len = 5792
  11. sequence_len = 6
  12. model_path = package_dir + "/models/phrase_f1_0.85.h5"
  13. char_path = package_dir + "/char.txt"
  14. with open(char_path, "r") as f:
  15. char_map_list = f.readlines()
  16. def recognize(char_list, model=None, sess=None):
  17. if sess is None:
  18. sess = tf.compat.v1.Session(graph=tf.Graph())
  19. if model is None:
  20. with sess.as_default():
  21. with sess.graph.as_default():
  22. model = lstm_phrase((sequence_len, vocabulary_len))
  23. model.load_weights(model_path)
  24. if len(char_list) > sequence_len:
  25. return None
  26. # 生成所有可能组合的映射数据
  27. char_dict = {}
  28. char_dict_reverse = {}
  29. for i in range(len(char_map_list)):
  30. char_dict[char_map_list[i][:-1]] = i
  31. char_dict_reverse[i] = char_map_list[i][:-1]
  32. char_dict_reverse[vocabulary_len] = ""
  33. index_list = [char_dict[x] for x in char_list]
  34. products = list(product(index_list, repeat=len(index_list)))
  35. all_index_list = []
  36. for index_list in products:
  37. index_list = list(index_list)
  38. if len(set(index_list)) != len(index_list):
  39. continue
  40. index_list = index_list + [vocabulary_len]*(sequence_len-len(index_list))
  41. all_index_list.append(index_list)
  42. X = np.array(all_index_list)
  43. # 预测
  44. with sess.as_default():
  45. with sess.graph.as_default():
  46. pred = model.predict(X)
  47. # for i in range(len(pred)):
  48. # p = all_index_list[i]
  49. # print(pred[i], "".join([char_dict_reverse[x] for x in p]))
  50. decode = all_index_list[int(np.argmax(pred))]
  51. decode = [char_dict_reverse[x] for x in decode][:len(char_list)]
  52. # print("".join(char_list))
  53. # print("".join(decode))
  54. return decode
  55. if __name__ == "__main__":
  56. text = "吃饭了"
  57. text = [x for x in text]
  58. random.shuffle(text)
  59. print("text", text)
  60. recognize(text)