inference_yolo_char.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import os
  2. import tensorflow as tf
  3. import cv2
  4. from keras.layers import Lambda, Input
  5. from keras.models import Model
  6. import numpy as np
  7. from chinese_detect.model_260 import tiny_yolo_body
  8. from chinese_detect.post_process import yolo_eval
  9. from utils import get_classes, get_colors, draw_boxes, get_anchors, pil_resize, np2pil
  10. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  11. model_path = _dir + "/models/char_yolo_loss_39.90.h5"
  12. anchors = get_anchors(_dir + "/yolo_data/my_anchors.txt")
  13. classes = get_classes(_dir + "/yolo_data/my_classes.txt")
  14. colors = get_colors(len(classes))
  15. image_shape = (160, 256, 3)
  16. def detect(image_np, model=None, sess=None, draw=True):
  17. if sess is None:
  18. sess = tf.compat.v1.Session(graph=tf.Graph())
  19. if model is None:
  20. with sess.as_default():
  21. with sess.graph.as_default():
  22. model = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)
  23. # create image input
  24. h, w = image_np.shape[:2]
  25. best_h, best_w = image_shape[:2]
  26. image_resize = pil_resize(image_np, best_h, best_w)
  27. # image_pil = np2pil(image_resize)
  28. image_resize = np.array(image_resize, dtype='float32')
  29. image_resize = image_resize.astype('float32') / 255.
  30. image_resize = np.expand_dims(image_resize, 0)
  31. # create image shape input
  32. need_shape = np.array([best_h, best_w])
  33. need_shape = np.expand_dims(need_shape, 0)
  34. # inference data
  35. with sess.as_default():
  36. with sess.graph.as_default():
  37. out_boxes, out_scores, out_classes = model.predict([image_resize, need_shape], steps=1)
  38. # print("image_size", need_shape)
  39. print("out_boxes", out_boxes)
  40. print("out_scores", out_scores)
  41. # print("out_classes", out_classes)
  42. out_boxes = out_boxes.astype(np.int32)
  43. out_classes = out_classes.astype(np.int32)
  44. # 还原
  45. out_boxes[:, 0] = h * out_boxes[:, 0] / best_h
  46. out_boxes[:, 2] = h * out_boxes[:, 2] / best_h
  47. out_boxes[:, 1] = w * out_boxes[:, 1] / best_w
  48. out_boxes[:, 3] = w * out_boxes[:, 3] / best_w
  49. image_pil = np2pil(image_np)
  50. if draw:
  51. # draw
  52. class_names = get_classes("yolo_data/my_classes.txt")
  53. colors = get_colors(len(class_names))
  54. image_resize, out_boxes = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
  55. image_np_result = cv2.cvtColor(np.array(image_resize), cv2.COLOR_RGB2BGR)
  56. cv2.imshow("result", image_np_result)
  57. cv2.waitKey(0)
  58. else:
  59. temp_boxes = []
  60. for i, c in reversed(list(enumerate(out_classes))):
  61. top, left, bottom, right = out_boxes[i]
  62. top = max(0, np.floor(top + 0.5).astype('int32'))
  63. left = max(0, np.floor(left + 0.5).astype('int32'))
  64. bottom = min(image_pil.size[1], np.floor(bottom + 0.5).astype('int32'))
  65. right = min(image_pil.size[0], np.floor(right + 0.5).astype('int32'))
  66. temp_boxes.append([(left, top), (right, bottom)])
  67. out_boxes = temp_boxes
  68. out_boxes = [[int(x[0][0]), int(x[0][1]), int(x[1][0]), int(x[1][1])] for x in out_boxes]
  69. return image_np, out_boxes, out_classes
  70. def get_tiny_inference_model(anchors, num_classes, weights_path='models/tiny_yolo_weights.h5'):
  71. """create the inference model, for Tiny YOLOv3"""
  72. image_input = Input(shape=(None, None, 3))
  73. image_shape = Input(shape=(2,), dtype='int64', name='image_shape')
  74. num_anchors = len(anchors)
  75. model_body = tiny_yolo_body(image_input, num_anchors//2, num_classes)
  76. print('Create Tiny YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))
  77. model_body.load_weights(weights_path)
  78. print('Load weights {}.'.format(weights_path))
  79. boxes, scores, classes = Lambda(yolo_eval,
  80. name='yolo_eval',
  81. arguments={'anchors': anchors,
  82. 'num_classes': num_classes}
  83. )([model_body.output, image_shape])
  84. model = Model([model_body.input, image_shape], [boxes, scores, classes])
  85. model.summary(120)
  86. return model
  87. if __name__ == '__main__':
  88. image_path = "D:/Project/captcha/data/test/phrase_5.jpg"
  89. detect(cv2.imread(image_path))