module.py

import contextlib
import logging
import math

import numpy as np
import torch
import torch.nn as nn

from botr.yolov8.yolo_utils import make_anchors, dist2bbox, make_divisible


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
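
# Quick check (illustrative): with autopad, stride-1 convolutions preserve H and W.
#   autopad(3)       # -> 1
#   autopad(3, d=2)  # -> 2 (dilated 3x3 kernel has effective size 5)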


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (used after BN fusion)."""
        return self.act(self.conv(x))
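
# Usage sketch (shapes illustrative): a stride-2 Conv halves the spatial dims.
#   x = torch.randn(1, 3, 64, 64)
#   Conv(3, 16, k=3, s=2)(x).shape  # -> torch.Size([1, 16, 32, 32])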


class Conv2(Conv):
    """Simplified RepConv module with Conv fusing."""

    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
        self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False)  # add 1x1 conv

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x) + self.cv2(x)))

    def fuse_convs(self):
        """Fuse parallel convolutions."""
        w = torch.zeros_like(self.conv.weight.data)
        i = [x // 2 for x in w.shape[2:]]
        w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
        self.conv.weight.data += w
        self.__delattr__('cv2')
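
# Fusion note: embedding the 1x1 weight at the spatial center of the k x k
# weight makes the single remaining conv equal to conv(x) + cv2(x), so Conv2
# collapses to a plain Conv at inference time.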


class DWConv(Conv):
    """Depth-wise convolution."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)


class ConvTranspose(nn.Module):
    """Convolution transpose 2d layer."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
        """Initialize ConvTranspose2d layer with batch normalization and activation function."""
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
        self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply transposed convolution, batch normalization and activation to input."""
        return self.act(self.bn(self.conv_transpose(x)))

    def forward_fuse(self, x):
        """Apply activation and transposed convolution to input (batch norm fused away)."""
        return self.act(self.conv_transpose(x))


class RepConv(nn.Module):
    """RepConv is a basic rep-style block with separate training and deploy topologies.

    This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
    """

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
        super().__init__()
        assert k == 3 and p == 1
        self.g = g
        self.c1 = c1
        self.c2 = c2
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
        self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
        self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
        self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)

    def forward_fuse(self, x):
        """Forward process with the single fused conv (available after fuse_convs())."""
        return self.act(self.conv(x))

    def forward(self, x):
        """Forward process with three parallel branches: 3x3, 1x1 and optional identity BN."""
        id_out = 0 if self.bn is None else self.bn(x)
        return self.act(self.conv1(x) + self.conv2(x) + id_out)

    def get_equivalent_kernel_bias(self):
        """Collapse the three branches into a single equivalent 3x3 kernel and bias."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
        kernelid, biasid = self._fuse_bn_tensor(self.bn)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _avg_to_3x3_tensor(self, avgp):
        channels = self.c1
        groups = self.g
        kernel_size = avgp.kernel_size
        input_dim = channels // groups
        k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
        k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
        return k

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, Conv):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        elif isinstance(branch, nn.BatchNorm2d):
            if not hasattr(self, 'id_tensor'):
                input_dim = self.c1 // self.g
                kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.c1):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def fuse_convs(self):
        """Replace the parallel branches with one equivalent conv for deployment."""
        if hasattr(self, 'conv'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
                              out_channels=self.conv1.conv.out_channels,
                              kernel_size=self.conv1.conv.kernel_size,
                              stride=self.conv1.conv.stride,
                              padding=self.conv1.conv.padding,
                              dilation=self.conv1.conv.dilation,
                              groups=self.conv1.conv.groups,
                              bias=True).requires_grad_(False)
        self.conv.weight.data = kernel
        self.conv.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('conv1')
        self.__delattr__('conv2')
        if hasattr(self, 'nm'):
            self.__delattr__('nm')
        if hasattr(self, 'bn'):
            self.__delattr__('bn')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
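
# Reparameterization check (illustrative): after fuse_convs(), the single 3x3
# conv reproduces the multi-branch training-time output.
#   m = RepConv(8, 8).eval()
#   x = torch.randn(1, 8, 32, 32)
#   y0 = m(x)
#   m.fuse_convs()
#   torch.allclose(y0, m.forward_fuse(x), atol=1e-5)  # expected: True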


class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Initialize a convolutional layer with a given number of input channels."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Decode box distributions: softmax over the c1 bins, then take the expected value."""
        b, c, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
        # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
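
# DFL note: each box side is predicted as a discrete distribution over c1 bins;
# the frozen conv (weights 0, 1, ..., c1 - 1) computes its expectation:
#   side = sum_i i * softmax(logits)_i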


class Concat(nn.Module):
    """Concatenate a list of tensors along a dimension."""

    def __init__(self, dimension=1):
        """Concatenates a list of tensors along a specified dimension."""
        super().__init__()
        self.d = dimension

    def forward(self, x):
        """Concatenate the list of input tensors along dimension self.d."""
        return torch.cat(x, self.d)


class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""

    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Forward pass through the SPPF layer."""
        x = self.cv1(x)
        y1 = self.m(x)
        y2 = self.m(y1)
        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
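
# SPPF note: three chained 5x5 max-pools (stride 1) yield receptive fields of
# 5, 9 and 13, matching SPP(k=(5, 9, 13)) at a fraction of the cost.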


class Bottleneck(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):  # ch_in, ch_out, shortcut, groups, kernels, expand
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Apply the bottleneck, with a residual connection when input and output channels match."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C2f(nn.Module):
    """CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
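
# Channel bookkeeping: cv1 emits 2 * c channels split into two halves; each of
# the n bottlenecks appends another c, so cv2 consumes (2 + n) * c channels.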


class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc)  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)
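
    # Inference output sketch (illustrative): for a 640x640 input with strides
    # 8/16/32, y has shape (batch, 4 + nc, 8400) -- xywh boxes in pixels
    # followed by per-class sigmoid scores. Note that self.format (checked
    # above) is assumed to be set externally by an export pipeline, alongside
    # self.export.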

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)


def fuse_conv_and_bn(conv, bn):
    """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
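
# Equivalence check (illustrative): the fused conv matches conv + BN in eval mode.
#   conv, bn = nn.Conv2d(8, 16, 3, bias=False), nn.BatchNorm2d(16).eval()
#   x = torch.randn(1, 8, 32, 32)
#   torch.allclose(fuse_conv_and_bn(conv, bn)(x), bn(conv(x)), atol=1e-5)  # True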


def fuse_deconv_and_bn(deconv, bn):
    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
    fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
                                    deconv.out_channels,
                                    kernel_size=deconv.kernel_size,
                                    stride=deconv.stride,
                                    padding=deconv.padding,
                                    output_padding=deconv.output_padding,
                                    dilation=deconv.dilation,
                                    groups=deconv.groups,
                                    bias=True).requires_grad_(False).to(deconv.weight.device)

    # Prepare filters
    w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(deconv.weight.size(1), device=deconv.weight.device) if deconv.bias is None else deconv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fuseddconv


def parse_model(d, ch):
    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
    import ast

    # Args
    max_channels = float('inf')
    nc, act, scales = (d.get(x) for x in ('nc', 'act', 'scales'))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
    if scales:
        scale = d.get('scale')
        if not scale:
            scale = tuple(scales.keys())[0]
            logging.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()

    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (Conv, ConvTranspose, Bottleneck, SPPF, DWConv, C2f, nn.ConvTranspose2d):
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]
            if m in (C2f,):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in (Detect,):
            args.append([ch[x] for x in f])
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        m.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
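
# Minimal usage sketch (hypothetical two-layer config; real YOLOv8 YAMLs define
# the full backbone/head graph plus per-scale depth/width multiples):
#   cfg = {'nc': 80,
#          'backbone': [[-1, 1, 'Conv', [64, 3, 2]]],
#          'head': [[[-1], 1, 'Detect', ['nc']]]}
#   model, save = parse_model(cfg, ch=3)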