inference_drag.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. import cv2
  3. import numpy as np
  4. import tensorflow as tf
  5. from border_recognize.model import u_net_drag
  6. from utils import pil_resize
  # Network input size as (height, width, channels); recognize() rescales the
  # predicted column by image_shape[1], i.e. the network-input width.
  image_shape = 128, 256, 3
  # Weights loaded when recognize() has to build the model itself.
  # NOTE(review): this is a relative path, so it resolves against the current
  # working directory, not this file's directory -- confirm that is intended.
  model_path = "./models/drag_f1_0.42.h5"
  # Absolute directory of this file followed by "/../" (its parent directory).
  # Not referenced by the code shown here; possibly imported by other modules.
  project_dir = "{}/../".format(os.path.dirname(os.path.abspath(__file__)))
  def recognize(image_np, model=None, sess=None):
      """Predict the column at which ``image_np`` should be split, using the U-Net drag model.

      The image is resized to the network input size (``image_shape``), scaled by
      1/255 and run through the model. The first output channel is averaged over
      rows, and the column with the highest mean response is mapped back to the
      width of ``image_np``.

      Args:
          image_np: Image array whose first two dimensions are read as
              (height, width). A file path is not accepted, because ``.shape`` is
              read from it. Pixel values are presumably in 0-255, since they are
              divided by 255 -- confirm with callers.
          model: Object exposing ``predict``. When None, one is built with
              ``u_net_drag(input_shape=image_shape)`` and the ``model_path``
              weights are loaded. This happens on every such call, so pass a
              model in for repeated use.
          sess: ``tf.compat.v1.Session`` used as the default session and graph
              while building the model and predicting. When None, a new session
              over a fresh ``tf.Graph()`` is created.
              NOTE(review): passing ``model`` without ``sess`` runs ``predict``
              under a brand-new, empty graph. Confirm this works with the
              TF/Keras version in use.

      Returns:
          int: Column index in ``image_np`` coordinates.
      """
      if sess is None:
          sess = tf.compat.v1.Session(graph=tf.Graph())
      if model is None:
          # Build the network and load its weights inside the session's own graph.
          with sess.as_default(), sess.graph.as_default():
              model = u_net_drag(input_shape=image_shape)
              model.load_weights(model_path)

      _, origin_width = image_np.shape[:2]

      network_input = pil_resize(image_np, image_shape[0], image_shape[1]) / 255.
      batch = np.expand_dims(network_input, 0)
      with sess.as_default(), sess.graph.as_default():
          prediction = model.predict(batch)

      # Take channel 0 of the single batch item, replicate it to 3 channels and
      # quantise to uint8 (truncating), keeping the original numeric pipeline.
      heatmap = prediction[0][..., 0]
      heatmap = np.uint8(np.repeat(heatmap[..., np.newaxis], 3, axis=-1) * 255.)

      # Mean response per column: average over rows (axis 0), then over channels.
      column_scores = heatmap.mean(axis=0).mean(axis=-1)
      best_column = int(np.argmax(column_scores))

      # Map the column from the network-input width back to the original width.
      return int(best_column * origin_width / image_shape[1])
  if __name__ == "__main__":
      _path = "../data/test/yolo_16.jpg"
      # _path = "../data/drag/3.jpg"
      # recognize() reads image_np.shape, so it needs a decoded image array.
      # Passing the path string directly raised AttributeError.
      # NOTE(review): cv2.imread returns BGR channel order -- confirm the model
      # was trained on BGR input.
      _image = cv2.imread(_path)
      if _image is None:
          # cv2.imread signals a missing or unreadable file by returning None.
          raise FileNotFoundError("Could not read image: {}".format(_path))
      print(recognize(_image))