yolo_utils.py

import logging
import math
import re
import time
from pathlib import Path
from types import SimpleNamespace

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision
import yaml
from PIL import Image


def yaml_load(file='data.yaml', append_filename=False):
    """
    Load YAML data from a file.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.

    Returns:
        dict: YAML data and file name.
    """
    with open(file, errors='ignore', encoding='utf-8') as f:
        s = f.read()  # string

        # Remove special characters
        if not s.isprintable():
            s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)

        # Add YAML filename to dict and return
        return {**yaml.safe_load(s), 'yaml_file': str(file)} if append_filename else yaml.safe_load(s)
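
# Illustrative usage sketch (not part of the original module): loading a hypothetical
# 'data.yaml' and keeping track of its source file via append_filename.
#
#   >>> cfg = yaml_load('data.yaml', append_filename=True)  # assumes data.yaml exists
#   >>> cfg['yaml_file']
#   'data.yaml'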


def smart_inference_mode():
    """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
    def decorate(fn):
        """Applies the appropriate torch decorator for inference mode based on the torch version."""
        major, minor = (int(v) for v in re.findall(r'\d+', torch.__version__)[:2])
        TORCH_1_9 = (major, minor) >= (1, 9)
        return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
    return decorate
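
# Illustrative usage sketch (not part of the original module): wrapping an inference
# helper so gradients are never tracked, regardless of the installed torch version.
#
#   >>> @smart_inference_mode()
#   ... def run_model(model, batch):
#   ...     return model(batch)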


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    major, minor = (int(v) for v in re.findall(r'\d+', torch.__version__)[:2])
    TORCH_1_10 = (major, minor) >= (1, 10)
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)
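
# Illustrative sketch (not part of the original module): anchor points for three dummy
# P3/P4/P5 feature maps of a 640x640 input give 80*80 + 40*40 + 20*20 = 8400 anchors.
#
#   >>> feats = [torch.zeros(1, 64, s, s) for s in (80, 40, 20)]
#   >>> anchors, strides = make_anchors(feats, strides=(8, 16, 32))
#   >>> anchors.shape, strides.shape
#   (torch.Size([8400, 2]), torch.Size([8400, 1]))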


def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance (ltrb) to box (xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox
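
# Illustrative sketch (not part of the original module): an anchor at (10, 10) with
# left/top/right/bottom distances (2, 3, 4, 5) decodes to center (11, 11), size (6, 8).
#
#   >>> dist2bbox(torch.tensor([[2., 3., 4., 5.]]), torch.tensor([[10., 10.]]))
#   tensor([[11., 11.,  6.,  8.]])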


def attempt_load_one_weight(weight, device=None, inplace=True):
    """Loads a single model's weights."""
    from botr.yolov8.module import Detect
    from botr.yolov8.model import DetectionModel

    model = DetectionModel()
    ckpt = model.load_state_dict(torch.load(weight))
    model.to(device).float()
    model = model.fuse().eval()  # model in eval mode

    # Module compatibility updates
    for m in model.modules():
        t = type(m)
        if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect):
            m.inplace = inplace  # torch 1.7.0 compatibility
        elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model and ckpt
    return model, ckpt


def xywh2xyxy(x):
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format, where (x1, y1) is
    the top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y
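
# Illustrative sketch (not part of the original module): a 20x10 box centered at (50, 50)
# maps to corners (40, 45) and (60, 55).
#
#   >>> xywh2xyxy(np.array([[50., 50., 20., 10.]]))
#   array([[40., 45., 60., 55.]])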


def box_iou(box1, box2, eps=1e-7):
    """
    Calculate intersection-over-union (IoU) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py

    Args:
        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
    """
    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)

    # IoU = inter / (area1 + area2 - inter)
    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
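
# Illustrative sketch (not part of the original module): two 10x10 boxes overlapping in a
# 5x5 region have IoU = 25 / (100 + 100 - 25) ≈ 0.143.
#
#   >>> b1 = torch.tensor([[0., 0., 10., 10.]])
#   >>> b2 = torch.tensor([[5., 5., 15., 15.]])
#   >>> box_iou(b1, b2)
#   tensor([[0.1429]])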


def clip_boxes(boxes, shape):
    """
    Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.

    Args:
        boxes (torch.Tensor): the bounding boxes to clip
        shape (tuple): the shape of the image
    """
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[..., 0].clamp_(0, shape[1])  # x1
        boxes[..., 1].clamp_(0, shape[0])  # y1
        boxes[..., 2].clamp_(0, shape[1])  # x2
        boxes[..., 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2


def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """
    Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
    (img1_shape) to the shape of a different image (img0_shape).

    Args:
        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
        img0_shape (tuple): the shape of the target image, in the format of (height, width).
        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
            calculated based on the size difference between the two images.

    Returns:
        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
            (img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    boxes[..., [0, 2]] -= pad[0]  # x padding
    boxes[..., [1, 3]] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
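
# Illustrative sketch (not part of the original module): a detection on a 640x640
# letterboxed image mapped back onto a 480x640 original; the 80 px top padding is removed.
#
#   >>> det = torch.tensor([[100., 180., 200., 280.]])
#   >>> scale_boxes((640, 640), det, (480, 640))
#   tensor([[100., 100., 200., 200.]])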


def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nc=0,  # number of classes (optional)
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Arguments:
        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing the predicted boxes, classes, and masks. The tensor should be in the format
            output by a model, such as YOLO.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, the model is agnostic to the number of classes, and all
            classes will be considered as one.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
            list contains the apriori labels for a given image. The list should be in the format
            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int): (optional) The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes passed into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels.

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4  # number of masks
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x.transpose(0, -1)[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            logging.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
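
# Illustrative sketch (not part of the original module): applying NMS to a dummy raw
# YOLOv8-style output of shape (batch, 4 + num_classes, num_anchors).
#
#   >>> raw = torch.rand(1, 84, 8400)   # 4 box coords + 80 class scores per anchor
#   >>> dets = non_max_suppression(raw, conf_thres=0.25, iou_thres=0.45)
#   >>> len(dets), dets[0].shape[1]     # one image; columns: x1, y1, x2, y2, conf, cls
#   (1, 6)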


def make_divisible(x, divisor):
    """Returns x rounded up to the nearest multiple of divisor."""
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor
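
# Illustrative sketch (not part of the original module): 97 rounded up to the next
# multiple of a stride of 32.
#
#   >>> make_divisible(97, 32)
#   128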


def initialize_weights(model):
    """Initialize model weights to random values."""
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True


def get_num_params(model):
    """Return the total number of parameters in a YOLO model."""
    return sum(x.numel() for x in model.parameters())


def get_num_gradients(model):
    """Return the total number of parameters with gradients in a YOLO model."""
    return sum(x.numel() for x in model.parameters() if x.requires_grad)


class LetterBox:
    """Resize image and padding for detection, instance segmentation, pose."""

    def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
        """Initialize LetterBox object with specific parameters."""
        self.new_shape = new_shape
        self.auto = auto
        self.scaleFill = scaleFill
        self.scaleup = scaleup
        self.stride = stride

    def __call__(self, labels=None, image=None):
        """Return updated labels and image with added border."""
        if labels is None:
            labels = {}
        img = labels.get('img') if image is None else image
        shape = img.shape[:2]  # current shape [height, width]
        new_shape = labels.pop('rect_shape', self.new_shape)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # Scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        if not self.scaleup:  # only scale down, do not scale up (for better val mAP)
            r = min(r, 1.0)

        # Compute padding
        ratio = r, r  # width, height ratios
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
        if self.auto:  # minimum rectangle
            dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride)  # wh padding
        elif self.scaleFill:  # stretch
            dw, dh = 0.0, 0.0
            new_unpad = (new_shape[1], new_shape[0])
            ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

        dw /= 2  # divide padding into 2 sides
        dh /= 2

        if labels.get('ratio_pad'):
            labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh))  # for evaluation
        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
                                 value=(114, 114, 114))  # add border

        if len(labels):
            labels = self._update_labels(labels, ratio, dw, dh)
            labels['img'] = img
            labels['resized_shape'] = new_shape
            return labels
        else:
            return img

    def _update_labels(self, labels, ratio, padw, padh):
        """Update labels."""
        labels['instances'].convert_bbox(format='xyxy')
        labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
        labels['instances'].scale(*ratio)
        labels['instances'].add_padding(padw, padh)
        return labels
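
# Illustrative sketch (not part of the original module): letterboxing a 480x640 BGR frame
# into a 640x640 canvas; the aspect ratio is kept and gray (114) bars pad the height.
#
#   >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   >>> LetterBox(new_shape=(640, 640))(image=frame).shape
#   (640, 640, 3)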


class LoadPilAndNumpy:

    def __init__(self, im0, imgsz=640):
        """Initialize PIL and Numpy Dataloader."""
        if not isinstance(im0, list):
            im0 = [im0]
        self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
        self.im0 = [self._single_check(im) for im in im0]
        self.imgsz = imgsz
        self.mode = 'image'
        # Generate fake paths
        self.bs = len(self.im0)
        self.source_type = ''

    @staticmethod
    def _single_check(im):
        """Validate and format an image to numpy array."""
        assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
        if isinstance(im, Image.Image):
            if im.mode != 'RGB':
                im = im.convert('RGB')
            im = np.asarray(im)[:, :, ::-1]  # RGB to BGR
            im = np.ascontiguousarray(im)  # contiguous
        return im

    def __len__(self):
        """Returns the length of the 'im0' attribute."""
        return len(self.im0)

    def __next__(self):
        """Returns batch paths, images, None, ''."""
        if self.count == 1:  # loop only once as it's batch inference
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, None, ''

    def __iter__(self):
        """Enables iteration for class LoadPilAndNumpy."""
        self.count = 0
        return self
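
# Illustrative sketch (not part of the original module): wrapping an in-memory PIL image
# for single-batch inference and iterating over it once.
#
#   >>> loader = LoadPilAndNumpy(Image.new('RGB', (640, 480)))
#   >>> paths, images, _, _ = next(iter(loader))
#   >>> len(images), images[0].shape
#   (1, (480, 640, 3))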