# model.py (8.5 KB)
  1. import logging
  2. import os
  3. from copy import deepcopy
  4. import torch
  5. import numpy as np
  6. import cv2
  7. import torch.nn as nn
  8. from format_convert.utils import log
  9. from botr.yolov8.module import Conv, Conv2, RepConv, ConvTranspose, DWConv, Detect, parse_model, fuse_conv_and_bn, \
  10. fuse_deconv_and_bn
  11. from botr.yolov8.yolo_utils import yaml_load, initialize_weights, smart_inference_mode, \
  12. attempt_load_one_weight, non_max_suppression, scale_boxes, LetterBox, LoadPilAndNumpy
  13. cfg_path = os.path.abspath(os.path.dirname(__file__)) + '/yolov8_model.yaml'
  14. class DetectionModel(nn.Module):
  15. """YOLOv8 detection model."""
  16. def __init__(self, cfg=cfg_path, ch=3):
  17. super().__init__()
  18. self.yaml = yaml_load(cfg) # cfg dict
  19. # Define model
  20. self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch) # model, savelist
  21. self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
  22. self.inplace = True
  23. # Build strides
  24. m = self.model[-1] # Detect()
  25. if isinstance(m, Detect):
  26. s = 256 # 2x min stride
  27. m.inplace = self.inplace
  28. forward = lambda x: self.forward(x)
  29. m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
  30. self.stride = m.stride
  31. m.bias_init() # only run once
  32. # Init weights, biases
  33. initialize_weights(self)
  34. def is_fused(self, thresh=10):
  35. """
  36. Check if the model has less than a certain threshold of BatchNorm layers.
  37. Args:
  38. thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
  39. Returns:
  40. (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
  41. """
  42. bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
  43. return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
  44. def fuse(self):
  45. """
  46. Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
  47. computation efficiency.
  48. Returns:
  49. (nn.Module): The fused model is returned.
  50. """
  51. if not self.is_fused():
  52. for m in self.model.modules():
  53. if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
  54. if isinstance(m, Conv2):
  55. m.fuse_convs()
  56. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  57. delattr(m, 'bn') # remove batchnorm
  58. m.forward = m.forward_fuse # update forward
  59. if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
  60. m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
  61. delattr(m, 'bn') # remove batchnorm
  62. m.forward = m.forward_fuse # update forward
  63. if isinstance(m, RepConv):
  64. m.fuse_convs()
  65. m.forward = m.forward_fuse # update forward
  66. return self
  67. def _forward_once(self, x):
  68. """
  69. Perform a forward pass through the network.
  70. Args:
  71. x (torch.Tensor): The input tensor to the model
  72. Returns:
  73. (torch.Tensor): The last output of the model.
  74. """
  75. y, dt = [], [] # outputs
  76. for m in self.model:
  77. if m.f != -1: # if not from previous layer
  78. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  79. x = m(x) # run
  80. y.append(x if m.i in self.save else None) # save output
  81. return x
  82. def forward(self, x):
  83. """Run forward pass on input image(s) with optional augmentation and profiling."""
  84. return self._forward_once(x) # single-scale inference, train
  85. class Predictor:
  86. """
  87. Predictor
  88. A class for creating predictors.
  89. """
  90. def __init__(self, image_size, device, model):
  91. """
  92. Initializes the BasePredictor class.
  93. Args:
  94. cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
  95. overrides (dict, optional): Configuration overrides. Defaults to None.
  96. """
  97. self.iou = 0.7
  98. self.agnostic_nms = False
  99. self.max_det = 300
  100. self.filter_classes = None
  101. self.confidence = 0.25 # default conf=0.25
  102. # Usable if setup is done
  103. self.model = None
  104. self.imgsz = image_size
  105. self.device = device
  106. self.dataset = None
  107. self.stride = 32
  108. # 读取模型
  109. self.setup_model(model)
  110. log('setup model: yolo v8 once!')
  111. def preprocess(self, im):
  112. """Prepares input image before inference.
  113. Args:
  114. im (torch.Tensor | List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
  115. """
  116. im = np.stack(self.pre_transform(im))
  117. im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
  118. im = np.ascontiguousarray(im) # contiguous
  119. im = torch.from_numpy(im)
  120. # NOTE: assuming im with (b, 3, h, w) if it's a tensor
  121. img = im.to(self.device)
  122. img = img.float() # uint8 to fp16/32
  123. img /= 255 # 0 - 255 to 0.0 - 1.0
  124. return img
  125. def pre_transform(self, im):
  126. """Pre-tranform input image before inference.
  127. Args:
  128. im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
  129. Return: A list of transformed imgs.
  130. """
  131. same_shapes = all(x.shape == im[0].shape for x in im)
  132. auto = same_shapes
  133. return [LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=x) for x in im]
  134. def postprocess(self, preds, img, orig_imgs):
  135. """Postprocesses predictions and returns a list of Results objects."""
  136. preds = non_max_suppression(preds,
  137. self.confidence,
  138. self.iou,
  139. agnostic=self.agnostic_nms,
  140. max_det=self.max_det,
  141. classes=self.filter_classes)
  142. results = []
  143. for i, pred in enumerate(preds):
  144. orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
  145. if not isinstance(orig_imgs, torch.Tensor):
  146. pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
  147. results.append(pred)
  148. return results
  149. def setup_source(self, source):
  150. """Sets up source and inference mode."""
  151. self.dataset = LoadPilAndNumpy(source, imgsz=self.imgsz)
  152. def setup_model(self, model):
  153. """Initialize YOLO model with given parameters and set it to evaluation mode."""
  154. self.model = attempt_load_one_weight(model,
  155. device=self.device,
  156. inplace=True)[0]
  157. self.model.float().eval()
  158. @smart_inference_mode()
  159. def stream_inference(self, source=None):
  160. """Streams real-time inference on camera feed and saves results to file."""
  161. # Setup model
  162. # if not self.model:
  163. # self.setup_model(model)
  164. # Setup source every time predict is called
  165. self.setup_source(source)
  166. results = []
  167. for batch in self.dataset:
  168. path, im0s, vid_cap, s = batch
  169. # Preprocess
  170. im = self.preprocess(im0s)
  171. # Inference
  172. preds = self.model(im)
  173. # Postprocess
  174. result = self.postprocess(preds, im, im0s)
  175. results.append(result[0].tolist())
  176. print('stream_inference self.results', result[0].tolist())
  177. return results
  178. def predict(self, source=None, show=False):
  179. """Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
  180. # source = cv2.imread(source)
  181. results = self.stream_inference(source)
  182. if show:
  183. self.show(source, results[0])
  184. return results
  185. def show(self, source, result):
  186. for r in result:
  187. bbox = r[:4]
  188. bbox = [int(x) for x in bbox]
  189. confidence = r[4]
  190. cv2.rectangle(source, bbox[:2], bbox[2:4], color=(0, 0, 255), thickness=1)
  191. cv2.putText(source, str(round(confidence, 2)), (bbox[0], bbox[1]),
  192. cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
  193. cv2.imshow('result', source)
  194. cv2.waitKey(0)