inference_yolo_char.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. import cv2
  3. from PIL import Image
  4. from keras.layers import Lambda, Input
  5. from keras.models import Model
  6. import numpy as np
  7. from click_captcha.model import tiny_yolo_body
  8. from click_captcha.post_process import yolo_eval, letterbox_image
  9. from click_captcha.utils import get_classes, get_colors, draw_boxes, get_anchors, pil_resize, np2pil
  10. def detect_char(image_path, model):
  11. image_np = cv2.imread(image_path)
  12. # create image input
  13. h, w = image_np.shape[:2]
  14. best_h, best_w = 160, 256
  15. image_resize = pil_resize(image_np, best_h, best_w)
  16. image_pil = np2pil(image_resize)
  17. image_resize = np.array(image_resize, dtype='float32')
  18. image_resize = image_resize.astype('float32') / 255.
  19. image_resize = np.expand_dims(image_resize, 0)
  20. # create image shape input
  21. image_shape = np.array([best_h, best_w])
  22. image_shape = np.expand_dims(image_shape, 0)
  23. # inference data
  24. out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape])
  25. print("image_size", image_shape)
  26. print("out_boxes", out_boxes)
  27. print("out_scores", out_scores)
  28. print("out_classes", out_classes)
  29. out_boxes = out_boxes.astype(np.int32)
  30. out_classes = out_classes.astype(np.int32)
  31. # draw
  32. # image_np = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
  33. class_names = get_classes("yolo_data/my_classes.txt")
  34. colors = get_colors(len(class_names))
  35. image_resize, out_boxes = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
  36. image_np_result = cv2.cvtColor(np.array(image_resize), cv2.COLOR_RGB2BGR)
  37. cv2.imshow("result", image_np_result)
  38. cv2.waitKey(0)
  39. return image_np, out_boxes, out_classes
  40. def get_tiny_inference_model(anchors, num_classes, weights_path='models/tiny_yolo_weights.h5'):
  41. """create the inference model, for Tiny YOLOv3"""
  42. image_input = Input(shape=(None, None, 3))
  43. image_shape = Input(shape=(2,), dtype='int64', name='image_shape')
  44. num_anchors = len(anchors)
  45. model_body = tiny_yolo_body(image_input, num_anchors//2, num_classes)
  46. print('Create Tiny YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))
  47. model_body.load_weights(weights_path)
  48. print('Load weights {}.'.format(weights_path))
  49. boxes, scores, classes = Lambda(yolo_eval,
  50. name='yolo_eval',
  51. arguments={'anchors': anchors,
  52. 'num_classes': num_classes}
  53. )([model_body.output, image_shape])
  54. # boxes, scores, classes = yolo_eval([model_body.output, image_shape], anchors, num_classes)
  55. model = Model([model_body.input, image_shape], [boxes, scores, classes])
  56. model.summary(120)
  57. return model
  58. if __name__ == '__main__':
  59. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  60. model_path = _dir + "/models/char_yolo_loss_39.90.h5"
  61. anchors = get_anchors(_dir + "/yolo_data/my_anchors.txt")
  62. class_names = get_classes(_dir + "/yolo_data/my_classes.txt")
  63. colors = get_colors(len(class_names))
  64. yolo_model = get_tiny_inference_model(anchors, len(class_names), weights_path=model_path)
  65. yolo_model.load_weights(model_path)
  66. image_path = "../data/test/phrase_5.jpg"
  67. # image_path = "../data/detect/1.jpg"
  68. detect_char(image_path, yolo_model)