inference_drag.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import os
  2. import sys
  3. import cv2
  4. import numpy as np
  5. import tensorflow as tf
  6. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  7. from border_recognize.model import u_net_drag
  8. from utils import pil_resize
  9. package_dir = os.path.abspath(os.path.dirname(__file__))
  10. image_shape = (128, 256, 3)
  11. model_path = package_dir + "/models/drag_f1_0.42.h5"
  12. def recognize(image_np, model=None, sess=None):
  13. if sess is None:
  14. sess = tf.compat.v1.Session(graph=tf.Graph())
  15. if model is None:
  16. with sess.as_default():
  17. with sess.graph.as_default():
  18. model = u_net_drag(input_shape=image_shape)
  19. model.load_weights(model_path)
  20. h_origin, w_origin = image_np.shape[:2]
  21. img = image_np
  22. img = pil_resize(img, image_shape[0], image_shape[1])
  23. # cv2.imshow("img", img)
  24. img = img / 255.
  25. X = np.expand_dims(img, 0)
  26. with sess.as_default():
  27. with sess.graph.as_default():
  28. pred = model.predict(X)
  29. pred = pred[0][..., 0]
  30. pred = np.expand_dims(pred, axis=-1)
  31. pred = np.concatenate([pred]*3, axis=-1)
  32. pred = np.uint8(pred*255.)
  33. # print(pred.shape)
  34. # cv2.imshow("pred", pred)
  35. w = int(np.argmax(np.mean(np.mean(pred, axis=0), axis=-1)))
  36. # print(w)
  37. # decode = np.zeros(pred.shape, np.uint8)
  38. # decode[:, w, :] = 255
  39. # cv2.imshow("decode", decode)
  40. #
  41. # img_new = np.concatenate([img[:, w:, :], img[:, :w, :]], axis=1)
  42. # cv2.imshow("img_new", img_new)
  43. # cv2.waitKey(0)
  44. # 还原
  45. w = int(w * w_origin / image_shape[1])
  46. return w
  47. if __name__ == "__main__":
  48. _path = "D:/Project/captcha/data/test/phrase_5.jpg"
  49. recognize(cv2.imread(_path))