# inference_yolo_char.py (5.5 KB)
  1. import os
  2. import sys
  3. import tensorflow as tf
  4. import cv2
  5. from keras.layers import Lambda, Input
  6. from keras.models import Model
  7. import numpy as np
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  9. from chinese_detect.model_260 import tiny_yolo_body
  10. from chinese_detect.post_process import yolo_eval
  11. from utils import get_classes, get_colors, draw_boxes, get_anchors, pil_resize, np2pil, limit_image_size
  12. package_dir = os.path.abspath(os.path.dirname(__file__))
  13. model_path = package_dir + "/models/char_yolo_loss_16.99.h5"
  14. anchors = get_anchors(package_dir + "/yolo_data/my_anchors_new.txt")
  15. classes = get_classes(package_dir + "/yolo_data/my_classes.txt")
  16. colors = get_colors(len(classes))
  17. image_shape = (160, 256, 3)
  18. tips_shape = (40, 160, 3)
  19. def detect(image_np, model=None, sess=None, draw=False, is_tips=0):
  20. if sess is None:
  21. sess = tf.compat.v1.Session(graph=tf.Graph())
  22. if model is None:
  23. with sess.as_default():
  24. with sess.graph.as_default():
  25. model = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)
  26. # create image input
  27. image_np = limit_image_size(image_np)
  28. h, w = image_np.shape[:2]
  29. # print("h, w", h, w)
  30. best_h, best_w = image_shape[:2]
  31. # 如是提示图片需paste
  32. if is_tips:
  33. image_tips = pil_resize(image_np, tips_shape[0], tips_shape[1])
  34. image_resize = np.zeros(image_shape, dtype=np.uint8)
  35. image_resize[:tips_shape[0], :tips_shape[1], :] = image_tips[:, :, :]
  36. else:
  37. image_resize = pil_resize(image_np, best_h, best_w)
  38. # cv2.imshow("image_resize", image_resize)
  39. # cv2.waitKey(0)
  40. # image_pil = np2pil(image_resize)
  41. image_resize = np.array(image_resize, dtype='float32')
  42. image_resize = image_resize.astype('float32') / 255.
  43. image_resize = np.expand_dims(image_resize, 0)
  44. # create image shape input
  45. need_shape = np.array([best_h, best_w])
  46. need_shape = np.expand_dims(need_shape, 0)
  47. # inference data
  48. with sess.as_default():
  49. with sess.graph.as_default():
  50. out_boxes, out_scores, out_classes = model.predict([image_resize, need_shape], steps=1)
  51. # print("image_size", need_shape)
  52. print("out_boxes", out_boxes)
  53. print("out_scores", out_scores)
  54. # print("out_classes", out_classes)
  55. out_boxes = out_boxes.astype(np.int32)
  56. out_classes = out_classes.astype(np.int32)
  57. # 还原
  58. if is_tips:
  59. out_boxes[:, 0] = h * out_boxes[:, 0] / tips_shape[0]
  60. out_boxes[:, 2] = h * out_boxes[:, 2] / tips_shape[0]
  61. out_boxes[:, 1] = w * out_boxes[:, 1] / tips_shape[1]
  62. out_boxes[:, 3] = w * out_boxes[:, 3] / tips_shape[1]
  63. else:
  64. out_boxes[:, 0] = h * out_boxes[:, 0] / best_h
  65. out_boxes[:, 2] = h * out_boxes[:, 2] / best_h
  66. out_boxes[:, 1] = w * out_boxes[:, 1] / best_w
  67. out_boxes[:, 3] = w * out_boxes[:, 3] / best_w
  68. image_pil = np2pil(image_np)
  69. if draw:
  70. # draw
  71. class_names = get_classes("yolo_data/my_classes.txt")
  72. colors = get_colors(len(class_names))
  73. image_resize, out_boxes = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
  74. image_np_result = cv2.cvtColor(np.array(image_resize), cv2.COLOR_RGB2BGR)
  75. cv2.imshow("result", image_np_result)
  76. cv2.waitKey(0)
  77. else:
  78. temp_boxes = []
  79. for i, c in reversed(list(enumerate(out_classes))):
  80. top, left, bottom, right = out_boxes[i]
  81. top = max(0, np.floor(top + 0.5).astype('int32'))
  82. left = max(0, np.floor(left + 0.5).astype('int32'))
  83. bottom = min(image_pil.size[1], np.floor(bottom + 0.5).astype('int32'))
  84. right = min(image_pil.size[0], np.floor(right + 0.5).astype('int32'))
  85. temp_boxes.append([(left, top), (right, bottom)])
  86. out_boxes = temp_boxes
  87. # 加大box
  88. threshold = 2
  89. out_boxes = [[max(int(x[0][0]-threshold), 0),
  90. max(int(x[0][1]-threshold), 0),
  91. min(int(x[1][0]+threshold), w),
  92. min(int(x[1][1]+threshold), h)] for x in out_boxes]
  93. out_boxes.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
  94. return image_np, out_boxes, out_classes
  95. def get_tiny_inference_model(anchors, num_classes, weights_path='models/tiny_yolo_weights.h5'):
  96. """create the inference model, for Tiny YOLOv3"""
  97. image_input = Input(shape=(None, None, 3))
  98. need_shape = Input(shape=(2,), dtype='int64', name='image_shape')
  99. num_anchors = len(anchors)
  100. model_body = tiny_yolo_body(image_input, num_anchors//2, num_classes)
  101. print('Create Tiny YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))
  102. model_body.load_weights(weights_path)
  103. print('Load weights {}.'.format(weights_path))
  104. boxes, scores, classes = Lambda(yolo_eval,
  105. name='yolo_eval',
  106. arguments={'anchors': anchors,
  107. 'num_classes': num_classes}
  108. )([model_body.output, need_shape])
  109. model = Model([model_body.input, need_shape], [boxes, scores, classes])
  110. # model.summary(120)
  111. return model
  112. if __name__ == '__main__':
  113. image_path = "D:/Project/captcha/data/test/yolo_3.jpg"
  114. _img = cv2.imread(image_path)
  115. cv2.imshow("origin_image", _img)
  116. _, boxes, _ = detect(_img, is_tips=1, draw=False)
  117. for box in boxes:
  118. cv2.imshow("sub", _img[box[1]:box[3], box[0]:box[2], :])
  119. cv2.waitKey(0)