inference_yolo_char.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import os
  2. import sys
  3. import tensorflow as tf
  4. import cv2
  5. from keras.layers import Lambda, Input
  6. from keras.models import Model
  7. import numpy as np
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  9. from chinese_detect.model_260 import tiny_yolo_body
  10. from chinese_detect.post_process import yolo_eval
  11. from utils import get_classes, get_colors, draw_boxes, get_anchors, pil_resize, np2pil
# Directory containing this file; model weights and YOLO data files are
# located relative to it.
package_dir = os.path.abspath(os.path.dirname(__file__))
# Weights file handed to get_tiny_inference_model() when detect() has to
# build the model itself.
model_path = package_dir + "/models/char_yolo_loss_39.90.h5"
# Anchors and class names are loaded once at import time.
anchors = get_anchors(package_dir + "/yolo_data/my_anchors.txt")
classes = get_classes(package_dir + "/yolo_data/my_classes.txt")
# One colour per class.
colors = get_colors(len(classes))
# (height, width, channels) of the image fed to the network in detect().
image_shape = (160, 256, 3)
# (height, width, channels) a "tips" image is resized to before being pasted
# into the top-left corner of a blank image_shape canvas in detect().
tips_shape = (40, 160, 3)
  19. def detect(image_np, model=None, sess=None, draw=False, is_tips=0):
  20. if sess is None:
  21. sess = tf.compat.v1.Session(graph=tf.Graph())
  22. if model is None:
  23. with sess.as_default():
  24. with sess.graph.as_default():
  25. model = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)
  26. # create image input
  27. h, w = image_np.shape[:2]
  28. # print("h, w", h, w)
  29. best_h, best_w = image_shape[:2]
  30. # 如是提示图片需paste
  31. if is_tips:
  32. image_tips = pil_resize(image_np, tips_shape[0], tips_shape[1])
  33. image_resize = np.zeros(image_shape, dtype=np.uint8)
  34. image_resize[:tips_shape[0], :tips_shape[1], :] = image_tips[:, :, :]
  35. else:
  36. image_resize = pil_resize(image_np, best_h, best_w)
  37. # image_pil = np2pil(image_resize)
  38. image_resize = np.array(image_resize, dtype='float32')
  39. image_resize = image_resize.astype('float32') / 255.
  40. image_resize = np.expand_dims(image_resize, 0)
  41. # create image shape input
  42. need_shape = np.array([best_h, best_w])
  43. need_shape = np.expand_dims(need_shape, 0)
  44. # inference data
  45. with sess.as_default():
  46. with sess.graph.as_default():
  47. out_boxes, out_scores, out_classes = model.predict([image_resize, need_shape], steps=1)
  48. # print("image_size", need_shape)
  49. print("out_boxes", out_boxes)
  50. print("out_scores", out_scores)
  51. # print("out_classes", out_classes)
  52. out_boxes = out_boxes.astype(np.int32)
  53. out_classes = out_classes.astype(np.int32)
  54. # 还原
  55. if is_tips:
  56. out_boxes[:, 0] = h * out_boxes[:, 0] / tips_shape[0]
  57. out_boxes[:, 2] = h * out_boxes[:, 2] / tips_shape[0]
  58. out_boxes[:, 1] = w * out_boxes[:, 1] / tips_shape[1]
  59. out_boxes[:, 3] = w * out_boxes[:, 3] / tips_shape[1]
  60. else:
  61. out_boxes[:, 0] = h * out_boxes[:, 0] / best_h
  62. out_boxes[:, 2] = h * out_boxes[:, 2] / best_h
  63. out_boxes[:, 1] = w * out_boxes[:, 1] / best_w
  64. out_boxes[:, 3] = w * out_boxes[:, 3] / best_w
  65. image_pil = np2pil(image_np)
  66. if draw:
  67. # draw
  68. class_names = get_classes("yolo_data/my_classes.txt")
  69. colors = get_colors(len(class_names))
  70. image_resize, out_boxes = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
  71. image_np_result = cv2.cvtColor(np.array(image_resize), cv2.COLOR_RGB2BGR)
  72. cv2.imshow("result", image_np_result)
  73. cv2.waitKey(0)
  74. else:
  75. temp_boxes = []
  76. for i, c in reversed(list(enumerate(out_classes))):
  77. top, left, bottom, right = out_boxes[i]
  78. top = max(0, np.floor(top + 0.5).astype('int32'))
  79. left = max(0, np.floor(left + 0.5).astype('int32'))
  80. bottom = min(image_pil.size[1], np.floor(bottom + 0.5).astype('int32'))
  81. right = min(image_pil.size[0], np.floor(right + 0.5).astype('int32'))
  82. temp_boxes.append([(left, top), (right, bottom)])
  83. out_boxes = temp_boxes
  84. # 加大box
  85. threshold = 2
  86. out_boxes = [[max(int(x[0][0]-threshold), 0),
  87. max(int(x[0][1]-threshold), 0),
  88. min(int(x[1][0]+threshold), w),
  89. min(int(x[1][1]+threshold), h)] for x in out_boxes]
  90. out_boxes.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
  91. return image_np, out_boxes, out_classes
  92. def get_tiny_inference_model(anchors, num_classes, weights_path='models/tiny_yolo_weights.h5'):
  93. """create the inference model, for Tiny YOLOv3"""
  94. image_input = Input(shape=(None, None, 3))
  95. need_shape = Input(shape=(2,), dtype='int64', name='image_shape')
  96. num_anchors = len(anchors)
  97. model_body = tiny_yolo_body(image_input, num_anchors//2, num_classes)
  98. print('Create Tiny YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))
  99. model_body.load_weights(weights_path)
  100. print('Load weights {}.'.format(weights_path))
  101. boxes, scores, classes = Lambda(yolo_eval,
  102. name='yolo_eval',
  103. arguments={'anchors': anchors,
  104. 'num_classes': num_classes}
  105. )([model_body.output, need_shape])
  106. model = Model([model_body.input, need_shape], [boxes, scores, classes])
  107. # model.summary(120)
  108. return model
  109. if __name__ == '__main__':
  110. image_path = "D:/Project/captcha/data/test/yolo_3.jpg"
  111. _img = cv2.imread(image_path)
  112. cv2.imshow("origin_image", _img)
  113. _, boxes, _ = detect(_img, is_tips=1, draw=False)
  114. for box in boxes:
  115. cv2.imshow("sub", _img[box[1]:box[3], box[0]:box[2], :])
  116. cv2.waitKey(0)