import os

import cv2
import numpy as np
import tensorflow as tf

from border_recognize.model import u_net_drag
from utils import pil_resize

# Network input shape as (height, width, channels); passed to u_net_drag and
# used as the resize target for incoming images.
image_shape = (128, 256, 3)
# NOTE(review): this path is resolved relative to the current working
# directory, not to this file; ``project_dir`` below is computed but unused
# here — confirm which base directory is intended.
model_path = "./models/drag_f1_0.42.h5"
project_dir = os.path.dirname(os.path.abspath(__file__)) + "/../"


def recognize(image_np, model=None, sess=None):
    """Predict the horizontal position (column) of the drag target in an image.

    The image is resized to ``image_shape``, scaled by 1/255 and fed to the
    ``u_net_drag`` model. The first output channel of the prediction is
    quantized to uint8 and averaged over rows. The column with the highest
    mean response is then mapped back to the original image width.

    Args:
        image_np: Image array indexed as ``shape[:2] == (height, width)``, or a
            path to an image file, which is loaded with ``cv2.imread``.
        model: Optional pre-built model. When omitted, a ``u_net_drag`` model
            is built and its weights are loaded from ``model_path``.
        sess: Optional ``tf.compat.v1.Session`` to run the model in. When
            omitted, a new session on a fresh graph is created for this call
            and closed before returning.

    Returns:
        int: column index in the original image's coordinate system.

    Raises:
        FileNotFoundError: if ``image_np`` is a path cv2 cannot read.
    """
    # Accept a file path as well as an already-decoded image.
    if isinstance(image_np, str):
        path = image_np
        image_np = cv2.imread(path)
        if image_np is None:
            raise FileNotFoundError("cannot read image: " + path)

    own_sess = sess is None
    if own_sess:
        sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        if model is None:
            with sess.as_default(), sess.graph.as_default():
                model = u_net_drag(input_shape=image_shape)
                model.load_weights(model_path)

        w_origin = image_np.shape[1]
        img = pil_resize(image_np, image_shape[0], image_shape[1])
        batch = np.expand_dims(img / 255., 0)

        with sess.as_default(), sess.graph.as_default():
            pred = model.predict(batch)
    finally:
        # Only release the session this call created; a caller-supplied
        # session is left open for reuse.
        if own_sess:
            sess.close()

    # Quantize the first output channel to uint8 (as before) so the argmax
    # and its first-index tie-breaking stay unchanged.
    heat = np.uint8(pred[0][..., 0] * 255.)
    w = int(np.argmax(np.mean(heat, axis=0)))

    # Map the column from network-input width back to the original width.
    return int(w * w_origin / image_shape[1])


if __name__ == "__main__":
    _path = "../data/test/yolo_16.jpg"
    print(recognize(_path))