inference_yolo_puzzle.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import os
  2. import sys
  3. import tensorflow as tf
  4. import cv2
  5. from keras.layers import Lambda, Input
  6. from keras.models import Model
  7. import numpy as np
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  9. from puzzle_detect.model_260 import tiny_yolo_body
  10. from puzzle_detect.post_process import yolo_eval, letterbox_image
  11. from utils import get_classes, get_colors, draw_boxes, get_anchors, pil_resize, np2pil
# Directory containing this file. The bundled model and data files are resolved
# relative to it, so loading them does not depend on the current working directory.
package_dir = os.path.abspath(os.path.dirname(__file__))
# Trained Tiny YOLOv3 weights for the puzzle detector (.h5).
model_path = package_dir + "/models/puzzle_yolo_loss_4.15.h5"
# Anchor boxes and class names, read once at import time through the utils helpers.
anchors = get_anchors(package_dir + "/yolo_data/my_anchors_puzzle.txt")
classes = get_classes(package_dir + "/yolo_data/my_classes_puzzle.txt")
# One drawing color per class, as produced by get_colors.
colors = get_colors(len(classes))
# Network input resolution as (height, width, channels). detect() reads only the
# first two entries; the inference model takes a single-channel (grayscale) input,
# so the trailing 3 appears to be unused here.
image_shape = (160, 256, 3)
  18. def detect(image_np, model=None, sess=None, draw=False):
  19. if sess is None:
  20. sess = tf.compat.v1.Session(graph=tf.Graph())
  21. if model is None:
  22. with sess.as_default():
  23. with sess.graph.as_default():
  24. model = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)
  25. # create image input
  26. h, w = image_np.shape[:2]
  27. best_h, best_w = image_shape[:2]
  28. image_resize = pil_resize(image_np, best_h, best_w)
  29. # image_pil = np2pil(image_resize)
  30. image_resize = cv2.cvtColor(image_resize, cv2.COLOR_BGR2GRAY)
  31. image_resize = 255. - image_resize
  32. image_resize = np.uint8(image_resize)
  33. image_resize = image_resize / 255.
  34. image_resize = np.expand_dims(image_resize, 0)
  35. image_resize = np.expand_dims(image_resize, -1)
  36. # create image shape input
  37. need_shape = np.array([best_h, best_w])
  38. need_shape = np.expand_dims(need_shape, 0)
  39. # inference data
  40. with sess.as_default():
  41. with sess.graph.as_default():
  42. out_boxes, out_scores, out_classes = model.predict([image_resize, need_shape], steps=1)
  43. print("out_boxes", out_boxes)
  44. print("out_scores", out_scores)
  45. out_boxes = out_boxes.astype(np.int32)
  46. out_classes = out_classes.astype(np.int32)
  47. # 还原
  48. out_boxes[:, 0] = h * out_boxes[:, 0] / best_h
  49. out_boxes[:, 2] = h * out_boxes[:, 2] / best_h
  50. out_boxes[:, 1] = w * out_boxes[:, 1] / best_w
  51. out_boxes[:, 3] = w * out_boxes[:, 3] / best_w
  52. image_pil = np2pil(image_np)
  53. if draw:
  54. # draw
  55. # image_np = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
  56. class_names = get_classes("yolo_data/my_classes_puzzle.txt")
  57. colors = get_colors(len(class_names))
  58. image_resize, out_boxes = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
  59. image_np_result = cv2.cvtColor(np.array(image_resize), cv2.COLOR_RGB2BGR)
  60. cv2.imshow("result", image_np_result)
  61. cv2.waitKey(0)
  62. else:
  63. temp_boxes = []
  64. for i, c in reversed(list(enumerate(out_classes))):
  65. top, left, bottom, right = out_boxes[i]
  66. top = max(0, np.floor(top + 0.5).astype('int32'))
  67. left = max(0, np.floor(left + 0.5).astype('int32'))
  68. bottom = min(image_pil.size[1], np.floor(bottom + 0.5).astype('int32'))
  69. right = min(image_pil.size[0], np.floor(right + 0.5).astype('int32'))
  70. temp_boxes.append([(left, top), (right, bottom)])
  71. out_boxes = temp_boxes
  72. out_boxes = [[int(x[0][0]), int(x[0][1]), int(x[1][0]), int(x[1][1])] for x in out_boxes]
  73. return image_np, out_boxes, out_classes
  74. def get_tiny_inference_model(anchors, num_classes, weights_path='models/tiny_yolo_weights.h5'):
  75. """create the inference model, for Tiny YOLOv3"""
  76. image_input = Input(shape=(None, None, 1))
  77. need_shape = Input(shape=(2,), dtype='int64', name='image_shape')
  78. num_anchors = len(anchors)
  79. model_body = tiny_yolo_body(image_input, num_anchors//2, num_classes)
  80. print('Create Tiny YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))
  81. model_body.load_weights(weights_path)
  82. print('Load weights {}.'.format(weights_path))
  83. boxes, scores, classes = Lambda(yolo_eval,
  84. name='yolo_eval',
  85. arguments={'anchors': anchors,
  86. 'num_classes': num_classes}
  87. )([model_body.output, need_shape])
  88. model = Model([model_body.input, need_shape], [boxes, scores, classes])
  89. model.summary(120)
  90. return model
  91. if __name__ == '__main__':
  92. image_path = "D:/Project/captcha/data/test/yolo_12.jpg"
  93. detect(cv2.imread(image_path))