import os
import random
from itertools import permutations, product

import cv2
import numpy as np

from click_captcha.model import mobile_net, cnn_net, u_net_drag, lstm_phrase, text_cnn_phrase
from click_captcha.utils import pil_resize

# Number of entries in the character vocabulary. The value itself doubles as
# the padding index: it is appended to short sequences and decodes to "".
vocabulary_len = 5792
# Fixed sequence length fed to the phrase model; shorter inputs are padded.
sequence_len = 6
# NOTE(review): both this path and the vocabulary path inside recognize() are
# relative to the current working directory, not to project_dir -- confirm
# the intended launch directory.
weights_path = "./models/e08-f10.85-phrase.h5"
project_dir = os.path.dirname(os.path.abspath(__file__)) + "/../"


def _load_char_maps(char_path, pad_index):
    """Read the vocabulary file and build char->index and index->char maps.

    Each line of the file holds one vocabulary entry; the zero-based line
    number is its index. ``pad_index`` is added to the reverse map and
    decodes to the empty string.
    """
    # NOTE(review): the file is opened with the platform default encoding,
    # as in the original; pass an explicit encoding once the file's encoding
    # is confirmed.
    with open(char_path, "r") as f:
        # Strip only the line terminator so that a final line without a
        # trailing newline keeps its last character.
        chars = [line.rstrip("\n") for line in f]

    char_dict = {ch: i for i, ch in enumerate(chars)}
    char_dict_reverse = dict(enumerate(chars))
    char_dict_reverse[pad_index] = ""
    return char_dict, char_dict_reverse


def _candidate_sequences(index_list, pad_index, length):
    """Return every distinct ordering of ``index_list``, padded to ``length``."""
    padding = [pad_index] * (length - len(index_list))
    # permutations() yields the same orderings, in the same order, as
    # product(..., repeat=n) filtered for repeats, without the n**n blow-up.
    # dict.fromkeys drops the identical orderings that arise when a
    # character occurs more than once, preserving first-seen order.
    orderings = dict.fromkeys(permutations(index_list))
    return [list(order) + padding for order in orderings]


def recognize(char_list):
    """Pick the ordering of ``char_list`` that the phrase model scores highest.

    Every distinct ordering of the given characters is encoded as a padded
    index sequence and passed to the model in one batch; the candidate at the
    arg-max of the prediction array is decoded back to text. Each candidate
    with its prediction, the input characters, and the decoded phrase are
    printed.

    Args:
        char_list: sequence of single characters, in any order.

    Returns:
        The decoded phrase as a string, or ``None`` when ``char_list`` has
        more than ``sequence_len`` characters.

    Raises:
        KeyError: if a character is not present in the vocabulary file.
    """
    # Validate before doing any expensive work (model build / weight load).
    if len(char_list) > sequence_len:
        return None

    char_path = "../data/phrase/char.txt"
    char_dict, char_dict_reverse = _load_char_maps(char_path, vocabulary_len)

    index_list = [char_dict[x] for x in char_list]
    all_index_list = _candidate_sequences(index_list, vocabulary_len, sequence_len)

    model = lstm_phrase((sequence_len, vocabulary_len))
    model.load_weights(weights_path)

    # presumably the model returns one score per candidate row -- the arg-max
    # below relies on that; verify against lstm_phrase.
    pred = model.predict(np.array(all_index_list))

    for score, candidate in zip(pred, all_index_list):
        print(score, "".join([char_dict_reverse[x] for x in candidate]))

    best = all_index_list[int(np.argmax(pred))]
    decode = "".join([char_dict_reverse[x] for x in best])
    print("".join(char_list))
    print(decode)
    return decode


if __name__ == "__main__":
    text = "吃饭了"
    text = list(text)
    random.shuffle(text)
    print("text", text)
    recognize(text)