model.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. import hashlib
  2. import logging
  3. import os
  4. from copy import deepcopy
  5. import torch
  6. import numpy as np
  7. import cv2
  8. import torch.nn as nn
  9. from format_convert.utils import log
  10. from botr.yolov8.module import Conv, Conv2, RepConv, ConvTranspose, DWConv, Detect, parse_model, fuse_conv_and_bn, \
  11. fuse_deconv_and_bn
  12. from botr.yolov8.yolo_utils import yaml_load, initialize_weights, smart_inference_mode, \
  13. attempt_load_one_weight, non_max_suppression, scale_boxes, LetterBox, LoadPilAndNumpy
  14. cfg_path = os.path.abspath(os.path.dirname(__file__)) + '/yolov8_model.yaml'
  15. class DetectionModel(nn.Module):
  16. """YOLOv8 detection model."""
  17. def __init__(self, cfg=cfg_path, ch=3):
  18. super().__init__()
  19. self.yaml = yaml_load(cfg) # cfg dict
  20. # Define model
  21. self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch) # model, savelist
  22. self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
  23. self.inplace = True
  24. # Build strides
  25. m = self.model[-1] # Detect()
  26. if isinstance(m, Detect):
  27. s = 256 # 2x min stride
  28. m.inplace = self.inplace
  29. forward = lambda x: self.forward(x)
  30. m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
  31. self.stride = m.stride
  32. m.bias_init() # only run once
  33. # Init weights, biases
  34. initialize_weights(self)
  35. def is_fused(self, thresh=10):
  36. """
  37. Check if the model has less than a certain threshold of BatchNorm layers.
  38. Args:
  39. thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
  40. Returns:
  41. (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
  42. """
  43. bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
  44. return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
  45. def fuse(self):
  46. """
  47. Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
  48. computation efficiency.
  49. Returns:
  50. (nn.Module): The fused model is returned.
  51. """
  52. if not self.is_fused():
  53. for m in self.model.modules():
  54. if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
  55. if isinstance(m, Conv2):
  56. m.fuse_convs()
  57. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  58. delattr(m, 'bn') # remove batchnorm
  59. m.forward = m.forward_fuse # update forward
  60. if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
  61. m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
  62. delattr(m, 'bn') # remove batchnorm
  63. m.forward = m.forward_fuse # update forward
  64. if isinstance(m, RepConv):
  65. m.fuse_convs()
  66. m.forward = m.forward_fuse # update forward
  67. return self
  68. def _forward_once(self, x):
  69. """
  70. Perform a forward pass through the network.
  71. Args:
  72. x (torch.Tensor): The input tensor to the model
  73. Returns:
  74. (torch.Tensor): The last output of the model.
  75. """
  76. y, dt = [], [] # outputs
  77. for m in self.model:
  78. if m.f != -1: # if not from previous layer
  79. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  80. x = m(x) # run
  81. y.append(x if m.i in self.save else None) # save output
  82. return x
  83. def forward(self, x):
  84. """Run forward pass on input image(s) with optional augmentation and profiling."""
  85. return self._forward_once(x) # single-scale inference, train
  86. class Predictor:
  87. """
  88. Predictor
  89. A class for creating predictors.
  90. """
  91. def __init__(self, image_size, device, model):
  92. """
  93. Initializes the BasePredictor class.
  94. Args:
  95. cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
  96. overrides (dict, optional): Configuration overrides. Defaults to None.
  97. """
  98. self.iou = 0.7
  99. self.agnostic_nms = False
  100. self.max_det = 300
  101. self.filter_classes = None
  102. self.confidence = 0.25 # default conf=0.25
  103. # Usable if setup is done
  104. self.model = None
  105. self.imgsz = image_size
  106. self.device = device
  107. self.dataset = None
  108. self.stride = 32
  109. # 读取模型
  110. self.setup_model(model)
  111. log('setup model: yolo v8 once!')
  112. def preprocess(self, im):
  113. """Prepares input image before inference.
  114. Args:
  115. im (torch.Tensor | List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
  116. """
  117. im = np.stack(self.pre_transform(im))
  118. im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
  119. im = np.ascontiguousarray(im) # contiguous
  120. im = torch.from_numpy(im)
  121. # NOTE: assuming im with (b, 3, h, w) if it's a tensor
  122. img = im.to(self.device)
  123. img = img.float() # uint8 to fp16/32
  124. img /= 255 # 0 - 255 to 0.0 - 1.0
  125. return img
  126. def pre_transform(self, im):
  127. """Pre-tranform input image before inference.
  128. Args:
  129. im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
  130. Return: A list of transformed imgs.
  131. """
  132. same_shapes = all(x.shape == im[0].shape for x in im)
  133. auto = same_shapes
  134. img_list = [LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=x) for x in im]
  135. # for img in img_list:
  136. # print('LetterBox img.shape', img.shape)
  137. # cv2.imshow('LetterBox', img)
  138. # cv2.waitKey(0)
  139. return img_list
  140. def postprocess(self, preds, img, orig_imgs):
  141. """Postprocesses predictions and returns a list of Results objects."""
  142. preds = non_max_suppression(preds,
  143. self.confidence,
  144. self.iou,
  145. agnostic=self.agnostic_nms,
  146. max_det=self.max_det,
  147. classes=self.filter_classes)
  148. results = []
  149. for i, pred in enumerate(preds):
  150. orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
  151. if not isinstance(orig_imgs, torch.Tensor):
  152. pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
  153. results.append(pred)
  154. return results
  155. def setup_source(self, source):
  156. """Sets up source and inference mode."""
  157. self.dataset = LoadPilAndNumpy(source, imgsz=self.imgsz)
  158. def setup_model(self, model):
  159. """Initialize YOLO model with given parameters and set it to evaluation mode."""
  160. self.model = attempt_load_one_weight(model,
  161. device=self.device,
  162. inplace=True)[0]
  163. self.model.float().eval()
  164. @smart_inference_mode()
  165. def stream_inference(self, source=None):
  166. """Streams real-time inference on camera feed and saves results to file."""
  167. # Setup model
  168. # if not self.model:
  169. # self.setup_model(model)
  170. # Setup source every time predict is called
  171. self.setup_source(source)
  172. results = []
  173. for batch in self.dataset:
  174. path, im0s, vid_cap, s = batch
  175. # print('im0s', im0s[0].shape)
  176. # _md5 = hashlib.md5(im0s[0])
  177. # print('md5', _md5.hexdigest())
  178. # cv2.imshow('im0s', im0s[0])
  179. # cv2.waitKey(0)
  180. # Preprocess
  181. im = self.preprocess(im0s)
  182. # print('im', im.shape)
  183. # Inference
  184. preds = self.model(im)
  185. # Postprocess
  186. result = self.postprocess(preds, im, im0s)
  187. results.append(result[0].tolist())
  188. print('stream_inference self.results', result[0].tolist())
  189. return results
  190. def predict(self, source=None, show=False):
  191. """Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
  192. # source = cv2.imread(source)
  193. results = self.stream_inference(source)
  194. if show:
  195. self.show(source, results[0])
  196. return results
  197. def show(self, source, result):
  198. for r in result:
  199. bbox = r[:4]
  200. bbox = [int(x) for x in bbox]
  201. confidence = r[4]
  202. cv2.rectangle(source, bbox[:2], bbox[2:4], color=(0, 0, 255), thickness=1)
  203. cv2.putText(source, str(round(confidence, 2)), (bbox[0], bbox[1]),
  204. cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
  205. cv2.imshow('result', source)
  206. cv2.waitKey(0)