1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
import os
import random
import sys
from itertools import permutations, product

import numpy as np
import tensorflow as tf

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from chinese_order.model import lstm_phrase, text_cnn_phrase
# Absolute directory of this module; the model weights and char.txt are
# resolved relative to it.
package_dir = os.path.abspath(os.path.dirname(__file__))
# Second input dimension passed to lstm_phrase; recognize() also uses this
# value as the padding index. Presumably equals the number of lines in
# char.txt -- NOTE(review): confirm against the shipped vocabulary file.
vocabulary_len = 5792
# Maximum number of characters the model accepts per phrase.
sequence_len = 6
model_path = package_dir + "/models/phrase_f1_0.85.h5"
char_path = package_dir + "/char.txt"
# Decode explicitly: the vocabulary holds Chinese characters, and relying on
# the platform default encoding (e.g. cp1252/GBK on Windows) fails or
# mis-decodes. Assumes char.txt is saved as UTF-8.
with open(char_path, "r", encoding="utf-8") as f:
    # One vocabulary entry per line; trailing newlines are kept here.
    char_map_list = f.readlines()
def recognize(char_list, model=None, sess=None):
    """Return the model's preferred ordering of the characters in *char_list*.

    Every distinct ordering of the input characters is encoded as a row of
    vocabulary indices, right-padded with ``vocabulary_len`` up to
    ``sequence_len``. All rows are scored in one ``model.predict`` batch, and
    the ordering at the overall argmax is decoded back to characters.

    Args:
        char_list: Sequence (or string) of characters, each of which must be
            an entry of ``char.txt``.
        model: Model used for scoring. If None, ``lstm_phrase`` is built and
            its weights are loaded from ``model_path``.
        sess: ``tf.compat.v1.Session`` to run the model under. If None, a
            fresh session on a new graph is created and closed before
            returning.

    Returns:
        List of characters in the predicted order, or None when
        ``char_list`` is longer than ``sequence_len``.

    Raises:
        KeyError: If a character is not present in ``char.txt``.
    """
    # Reject over-long input before paying for session/model setup.
    if len(char_list) > sequence_len:
        return None

    # rstrip instead of [:-1]: a final line without a trailing newline would
    # otherwise lose its last real character.
    vocab = [line.rstrip("\n") for line in char_map_list]
    char_to_index = {ch: i for i, ch in enumerate(vocab)}
    indices = [char_to_index[ch] for ch in char_list]

    padding = [vocabulary_len] * (sequence_len - len(indices))
    # permutations() yields the n! orderings directly, in the same
    # lexicographic order the old product()+filter produced for distinct
    # inputs. product() walked n**n tuples, and it discarded *every*
    # candidate when a character repeated. dict.fromkeys removes orderings
    # duplicated by repeated characters while keeping first-seen order.
    candidates = [
        list(order) + padding
        for order in dict.fromkeys(permutations(indices))
    ]
    X = np.array(candidates)

    owns_sess = sess is None
    if owns_sess:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        with sess.as_default(), sess.graph.as_default():
            if model is None:
                model = lstm_phrase((sequence_len, vocabulary_len))
                model.load_weights(model_path)
            pred = model.predict(X)
    finally:
        # Only close a session this call created; a caller's session is theirs.
        if owns_sess:
            sess.close()

    # Flat argmax over the whole prediction array. This assumes one score
    # per candidate row -- NOTE(review): confirm the model's output shape.
    best = candidates[int(np.argmax(pred))]
    # Padding positions lie beyond len(indices) and are not decoded.
    return [vocab[i] for i in best[:len(indices)]]
if __name__ == "__main__":
    # Demo: shuffle a short phrase and let the model restore its order.
    text = list("吃饭了")
    random.shuffle(text)
    print("text", text)
    # Previously the result was discarded, so the demo never showed it.
    print("decode", recognize(text))
|