import hashlib
import logging
import os
from copy import deepcopy

import torch
import numpy as np
import cv2
import torch.nn as nn

from format_convert.utils import log
from botr.yolov8.module import (Conv, Conv2, RepConv, ConvTranspose, DWConv, Detect, parse_model,
                                fuse_conv_and_bn, fuse_deconv_and_bn)
from botr.yolov8.yolo_utils import (yaml_load, initialize_weights, smart_inference_mode,
                                    attempt_load_one_weight, non_max_suppression, scale_boxes,
                                    LetterBox, LoadPilAndNumpy)

# Default model definition: 'yolov8_model.yaml' located in the same directory as this file.
cfg_path = os.path.abspath(os.path.dirname(__file__)) + '/yolov8_model.yaml'


class DetectionModel(nn.Module):
    """YOLOv8 detection model built from a YAML model definition."""

    def __init__(self, cfg=cfg_path, ch=3):
        """
        Build the network described by ``cfg``.

        Args:
            cfg (str): Path to the model YAML, read with ``yaml_load``. Defaults to ``cfg_path``.
            ch (int): Number of input channels. Defaults to 3.
        """
        super().__init__()
        self.yaml = yaml_load(cfg)  # cfg dict

        # Define model: parse_model returns the layer container and the indices of
        # layers whose outputs must be kept for later layers (see _forward_once).
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch)  # model, savelist
        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names: class index as str
        self.inplace = True

        # Build strides: run one dummy forward pass of size s x s and derive each output
        # level's stride from how much its feature map height shrank relative to the input.
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])
            self.stride = m.stride
            m.bias_init()  # only run once

        # Init weights, biases
        initialize_weights(self)

    def is_fused(self, thresh=10):
        """
        Check whether the model has fewer normalization layers than ``thresh``.

        Args:
            thresh (int, optional): Threshold number of normalization layers. Default is 10.

        Returns:
            (bool): True if fewer than ``thresh`` modules are instances of a ``torch.nn``
                class whose name contains 'Norm', False otherwise.
        """
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

    def fuse(self):
        """
        Fold ``BatchNorm`` layers into the preceding convolutions to speed up inference.

        Conv/Conv2/DWConv and ConvTranspose modules that still own a ``bn`` attribute get
        their (de)convolution replaced by a fused one, lose ``bn`` and switch to
        ``forward_fuse``. RepConv modules merge their branches via ``fuse_convs``.
        Nothing is done if ``is_fused()`` already reports the model as fused.

        Returns:
            (nn.Module): The model itself (fused in place).
        """
        if not self.is_fused():
            for m in self.model.modules():
                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
                    if isinstance(m, Conv2):
                        m.fuse_convs()
                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                    delattr(m, 'bn')  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
                    m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                    delattr(m, 'bn')  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepConv):
                    m.fuse_convs()
                    m.forward = m.forward_fuse  # update forward
        return self

    def _forward_once(self, x):
        """
        Perform a forward pass through the network.

        Each layer ``m`` reads its input according to ``m.f``: -1 means the previous
        layer's output, an int selects one earlier output, and a list gathers several
        (with -1 again meaning the previous output). Only outputs of layers whose index
        ``m.i`` is in ``self.save`` are retained, to limit memory use.

        Args:
            x (torch.Tensor): The input tensor to the model.

        Returns:
            The output of the last layer.
        """
        y = []  # outputs of earlier layers (None where not needed later)
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
        return x

    def forward(self, x):
        """Run a single-scale forward pass on ``x`` (no augmentation or profiling)."""
        return self._forward_once(x)


class Predictor:
    """Runs a YOLOv8 model on images: letterbox, inference, NMS, and rescaling of boxes."""

    def __init__(self, image_size, device, model):
        """
        Create the predictor and load the model weights.

        Args:
            image_size: Target size passed to ``LetterBox`` and ``LoadPilAndNumpy``.
            device: Torch device the input tensors are moved to and the weights are loaded on.
            model: Weights passed to ``attempt_load_one_weight`` (presumably a checkpoint
                path — TODO confirm).
        """
        # Non-maximum-suppression settings.
        self.iou = 0.7
        self.agnostic_nms = False
        self.max_det = 300
        self.filter_classes = None
        self.confidence = 0.25  # default conf=0.25

        # Populated by setup_model / setup_source.
        self.model = None
        self.imgsz = image_size
        self.device = device
        self.dataset = None
        self.stride = 32

        # Load model weights
        self.setup_model(model)
        log('setup model: yolo v8 once!')

    def preprocess(self, im):
        """
        Convert images into a normalized float tensor on ``self.device``.

        Args:
            im (List[np.ndarray]): Images of shape (h, w, 3), assumed to be in BGR order.

        Returns:
            (torch.Tensor): Tensor of shape (N, 3, H, W) with values in [0.0, 1.0].
        """
        im = np.stack(self.pre_transform(im))
        im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
        im = np.ascontiguousarray(im)  # contiguous (the flip/transpose above produce a strided view)
        im = torch.from_numpy(im)

        img = im.to(self.device)
        img = img.float()  # uint8 to fp32
        img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

    def pre_transform(self, im):
        """
        Letterbox every image to ``self.imgsz``.

        ``auto`` is enabled only when all images share the same shape.

        Args:
            im (List[np.ndarray]): Images of shape (h, w, 3).

        Returns:
            (list): The letterboxed images.
        """
        same_shapes = all(x.shape == im[0].shape for x in im)
        auto = same_shapes
        return [LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=x) for x in im]

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply NMS and map boxes back to the original image coordinates.

        Args:
            preds: Raw model output.
            img (torch.Tensor): The preprocessed batch fed to the model (used for its H, W).
            orig_imgs: The original image(s); a list is indexed per prediction.

        Returns:
            (list): One detection tensor per image, as returned by ``non_max_suppression``,
                with the first four columns rescaled by ``scale_boxes``.
        """
        preds = non_max_suppression(preds,
                                    self.confidence,
                                    self.iou,
                                    agnostic=self.agnostic_nms,
                                    max_det=self.max_det,
                                    classes=self.filter_classes)

        results = []
        for i, pred in enumerate(preds):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            if not isinstance(orig_imgs, torch.Tensor):
                pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(pred)
        return results

    def setup_source(self, source):
        """Wrap ``source`` in a ``LoadPilAndNumpy`` dataset sized to ``self.imgsz``."""
        self.dataset = LoadPilAndNumpy(source, imgsz=self.imgsz)

    def setup_model(self, model):
        """Load the weights onto ``self.device`` and put the model in float32 evaluation mode."""
        self.model = attempt_load_one_weight(model, device=self.device, inplace=True)[0]
        self.model.float().eval()

    @smart_inference_mode()
    def stream_inference(self, source=None):
        """
        Run inference on every batch produced from ``source``.

        Args:
            source: Image input accepted by ``LoadPilAndNumpy`` (presumably numpy/PIL
                image(s) — TODO confirm).

        Returns:
            (list): One entry per image; each entry is a list of detection rows whose first
                four values are box corners in original-image coordinates and whose fifth is
                the confidence.
        """
        # Setup source every time predict is called
        self.setup_source(source)

        results = []
        for batch in self.dataset:
            _path, im0s, _vid_cap, _s = batch

            im = self.preprocess(im0s)
            preds = self.model(im)
            result = self.postprocess(preds, im, im0s)

            # Keep the detections for every image in the batch, not only the first one.
            results.extend(pred.tolist() for pred in result)
        return results

    def predict(self, source=None, show=False):
        """
        Run inference on ``source`` and return the detections as a list.

        Args:
            source: Image input accepted by ``stream_inference``.
            show (bool): If True, draw the first image's detections onto ``source`` and
                display it with ``show``.

        Returns:
            (list): The value returned by ``stream_inference``.
        """
        results = self.stream_inference(source)
        if show:
            self.show(source, results[0])
        return results

    def show(self, source, result):
        """
        Draw detections onto ``source`` in place and display it in an OpenCV window.

        Blocks until a key is pressed.

        Args:
            source (np.ndarray): Image to draw on.
            result (list): Detection rows; ``r[:4]`` is the box (x1, y1, x2, y2) and ``r[4]``
                the confidence.
        """
        for r in result:
            bbox = [int(x) for x in r[:4]]
            confidence = r[4]
            cv2.rectangle(source, tuple(bbox[:2]), tuple(bbox[2:4]), color=(0, 0, 255), thickness=1)
            cv2.putText(source, str(round(confidence, 2)), (bbox[0], bbox[1]),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
        cv2.imshow('result', source)
        cv2.waitKey(0)