123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- import os
- import random
- from itertools import product
- import cv2
- import numpy as np
- from click_captcha.model import mobile_net, cnn_net, u_net_drag, lstm_phrase, text_cnn_phrase
- from click_captcha.utils import pil_resize
# Number of entries in the character vocabulary. recognize() also uses this
# value as the padding index, one past the last real character index.
vocabulary_len = 5792

# Fixed sequence length fed to the phrase model; longer inputs are rejected.
sequence_len = 6

# Weights file loaded into the phrase model (resolved against the current
# working directory, not against this file).
weights_path = "./models/e08-f10.85-phrase.h5"

# Parent directory of the directory containing this file, with a trailing
# slash. Not referenced by the code visible in this module.
project_dir = "{}/../".format(os.path.dirname(os.path.abspath(__file__)))
def recognize(char_list):
    """Reorder shuffled characters into the ordering the phrase model prefers.

    Every ordering of ``char_list`` is encoded as a row of vocabulary
    indices, right-padded with ``vocabulary_len`` up to ``sequence_len``,
    and passed to the ``lstm_phrase`` model in one batch. Each candidate is
    printed with its prediction, then the input and the winning ordering
    are printed.

    Args:
        char_list: Sequence of single-character strings. Every character
            must appear in ``../data/phrase/char.txt``.

    Returns:
        The ordering at the argmax of the model predictions, joined into a
        string, or ``None`` when ``char_list`` has more than
        ``sequence_len`` characters.

    Raises:
        KeyError: If a character is not present in the vocabulary file.
    """
    # Local import: only this function needs it, and the module-level
    # import block is left untouched.
    from itertools import permutations

    # Reject over-long input before paying for model construction and
    # weight loading.
    if len(char_list) > sequence_len:
        return None

    model = lstm_phrase((sequence_len, vocabulary_len))
    model.load_weights(weights_path)

    # NOTE(review): this path is relative to the current working directory
    # and the file is read with the platform default encoding -- confirm
    # both match how the vocabulary file is deployed.
    char_path = "../data/phrase/char.txt"
    with open(char_path, "r") as f:
        # Strip only the line terminator. Slicing with [:-1] would cut the
        # last character off a final line that has no trailing newline.
        vocabulary = [line.rstrip("\n") for line in f]

    char_dict = {char: index for index, char in enumerate(vocabulary)}
    char_dict_reverse = dict(enumerate(vocabulary))
    # The padding index decodes to nothing, so padded rows print cleanly.
    char_dict_reverse[vocabulary_len] = ""

    index_list = [char_dict[x] for x in char_list]
    padding = [vocabulary_len] * (sequence_len - len(index_list))

    # permutations() yields the n! orderings directly, in the same order
    # the filtered n**n product did. dict.fromkeys() drops the repeats that
    # arise when the input contains the same character more than once,
    # while keeping first-seen order.
    orderings = dict.fromkeys(permutations(index_list))
    all_index_list = [list(ordering) + padding for ordering in orderings]

    X = np.array(all_index_list)
    pred = model.predict(X)
    for score, candidate in zip(pred, all_index_list):
        print(score, "".join([char_dict_reverse[x] for x in candidate]))

    # NOTE(review): argmax runs over the flattened prediction array, which
    # equals the row index only if the model emits one score per row --
    # confirm against the model definition.
    best = all_index_list[int(np.argmax(pred))]
    decode = "".join([char_dict_reverse[x] for x in best])
    print("".join(char_list))
    print(decode)
    return decode
if __name__ == "__main__":
    # Demo run: shuffle the characters of a sample phrase and pass them to
    # recognize(), which prints every candidate ordering and its choice.
    text = list("吃饭了")
    random.shuffle(text)
    print("text", text)
    recognize(text)
|