import os
import random
import sys
from itertools import product

import tensorflow as tf
import numpy as np

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from chinese_order.model import lstm_phrase, text_cnn_phrase

package_dir = os.path.abspath(os.path.dirname(__file__))
vocabulary_len = 5792
sequence_len = 6
model_path = package_dir + "/models/phrase_f1_0.85.h5"
char_path = package_dir + "/char.txt"

# char.txt lists one character per line; the line index is the character's id
with open(char_path, "r", encoding="utf-8") as f:
    char_map_list = f.readlines()


def recognize(char_list, model=None, sess=None):
    """Reorder a shuffled list of characters into the most likely phrase."""
    if sess is None:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    if model is None:
        with sess.as_default():
            with sess.graph.as_default():
                model = lstm_phrase((sequence_len, vocabulary_len))
                model.load_weights(model_path)
    if len(char_list) > sequence_len:
        return None

    # Generate the mapping data for all possible combinations
    char_dict = {}
    char_dict_reverse = {}
    for i in range(len(char_map_list)):
        char_dict[char_map_list[i][:-1]] = i
        char_dict_reverse[i] = char_map_list[i][:-1]
    char_dict_reverse[vocabulary_len] = ""  # padding index maps to empty string

    index_list = [char_dict[x] for x in char_list]
    products = list(product(index_list, repeat=len(index_list)))
    all_index_list = []
    for candidate in products:
        candidate = list(candidate)
        # Keep only true permutations (no repeated indices)
        if len(set(candidate)) != len(candidate):
            continue
        # Pad to sequence_len with the blank index
        candidate = candidate + [vocabulary_len] * (sequence_len - len(candidate))
        all_index_list.append(candidate)
    X = np.array(all_index_list)

    # Predict a score for every permutation and keep the best one
    with sess.as_default():
        with sess.graph.as_default():
            pred = model.predict(X)
    # for i in range(len(pred)):
    #     p = all_index_list[i]
    #     print(pred[i], "".join([char_dict_reverse[x] for x in p]))
    decode = all_index_list[int(np.argmax(pred))]
    decode = [char_dict_reverse[x] for x in decode][:len(char_list)]
    # print("".join(char_list))
    # print("".join(decode))
    return decode


if __name__ == "__main__":
    text = "吃饭了"
    text = [x for x in text]
    random.shuffle(text)
    print("text", text)
    print("decode", recognize(text))