# inference_drag.py (1.3 KB)
  1. import os
  2. import cv2
  3. import numpy as np
  4. from click_captcha.model import mobile_net, cnn_net, u_net_drag
  5. from click_captcha.utils import pil_resize
  6. image_shape = (128, 256, 3)
  7. weights_path = "./models/drag_f1_0.42.h5"
  8. project_dir = os.path.dirname(os.path.abspath(__file__)) + "/../"
  9. def recognize(image_path):
  10. model = u_net_drag(input_shape=image_shape)
  11. model.load_weights(weights_path)
  12. img = cv2.imread(image_path)
  13. img = pil_resize(img, image_shape[0], image_shape[1])
  14. cv2.imshow("img", img)
  15. img = img / 255.
  16. X = np.expand_dims(img, 0)
  17. pred = model.predict(X)
  18. pred = pred[0][..., 0]
  19. pred = np.expand_dims(pred, axis=-1)
  20. pred = np.concatenate([pred]*3, axis=-1)
  21. pred = np.uint8(pred*255.)
  22. print(pred.shape)
  23. cv2.imshow("pred", pred)
  24. w = np.argmax(np.mean(np.mean(pred, axis=0), axis=-1))
  25. print(w)
  26. decode = np.zeros(pred.shape, np.uint8)
  27. decode[:, w, :] = 255
  28. cv2.imshow("decode", decode)
  29. cv2.waitKey(0)
  30. # index = int(np.argmax(pred))
  31. # with open(project_dir + "data/chinese_5710.txt") as f:
  32. # char_str = f.read()
  33. # char = char_str[index]
  34. # print("recognize chinese", char)
  35. return
  36. if __name__ == "__main__":
  37. _path = "../data/test/yolo_16.jpg"
  38. # _path = "../data/drag/3.jpg"
  39. recognize(_path)