"""Run the drag-captcha U-Net on one image and display the prediction.

The script loads the network weights, predicts a mask for the input image,
picks the image column with the highest mean response and shows the input,
the predicted mask and the selected column in OpenCV windows.
"""
import os

import cv2
import numpy as np

from click_captcha.model import mobile_net, cnn_net, u_net_drag
from click_captcha.utils import pil_resize

# Shape handed to the network constructor and to ``pil_resize``.
# NOTE(review): presumably (height, width, channels) -- confirm against
# ``u_net_drag`` / ``pil_resize``.
image_shape = (128, 256, 3)
# NOTE(review): relative path, so it is resolved against the current working
# directory rather than against this file's location.
weights_path = "./models/drag_f1_0.42.h5"
project_dir = os.path.dirname(os.path.abspath(__file__)) + "/../"


def recognize(image_path):
    """Predict the drag position for ``image_path`` and display the result.

    Three OpenCV windows are opened ("img", "pred", "decode") and the call
    blocks until a key is pressed. The shape of the prediction image and the
    selected column index are printed to stdout.

    Args:
        image_path: Path of the image file to run the model on.

    Returns:
        None. The result is only printed and displayed.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If OpenCV could not read or decode the file.
    """
    # ``cv2.imread`` signals failure by returning None instead of raising, so
    # check explicitly and fail before the (expensive) model is built.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"OpenCV could not read or decode image: {image_path}")

    model = u_net_drag(input_shape=image_shape)
    model.load_weights(weights_path)

    img = pil_resize(img, image_shape[0], image_shape[1])
    cv2.imshow("img", img)

    # Scale pixel values by 1/255 and add a leading batch axis.
    img = img / 255.
    X = np.expand_dims(img, 0)
    pred = model.predict(X)

    # Keep channel 0 of the first batch item, then replicate it into three
    # identical channels and rescale by 255 so it can be shown as an image.
    pred = pred[0][..., 0]
    pred = np.expand_dims(pred, axis=-1)
    pred = np.concatenate([pred] * 3, axis=-1)
    pred = np.uint8(pred * 255.)
    print(pred.shape)
    cv2.imshow("pred", pred)

    # Average over axis 0, then over the channel axis, leaving one value per
    # index of axis 1; the index with the largest mean is the result.
    w = np.argmax(np.mean(np.mean(pred, axis=0), axis=-1))
    print(w)

    # Visualise the chosen index as a white line on a black image.
    decode = np.zeros(pred.shape, np.uint8)
    decode[:, w, :] = 255
    cv2.imshow("decode", decode)
    cv2.waitKey(0)


if __name__ == "__main__":
    _path = "../data/test/yolo_16.jpg"
    # _path = "../data/drag/3.jpg"
    recognize(_path)