ConvNext.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. from functools import partial
  2. import logging
  3. import os
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torchocr.networks.backbones.Transformer import DropPath
  8. class Block(nn.Module):
  9. r""" ConvNeXt Block. There are two equivalent implementations:
  10. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  11. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  12. We use (2) as we find it slightly faster in PyTorch
  13. Args:
  14. dim (int): Number of input channels.
  15. drop_path (float): Stochastic depth rate. Default: 0.0
  16. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
  17. """
  18. def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
  19. super().__init__()
  20. self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
  21. self.norm = LayerNorm(dim, eps=1e-6)
  22. self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
  23. self.act = nn.GELU()
  24. self.pwconv2 = nn.Linear(4 * dim, dim)
  25. self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
  26. requires_grad=True) if layer_scale_init_value > 0 else None
  27. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  28. def forward(self, x):
  29. input = x
  30. x = self.dwconv(x)
  31. x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
  32. x = self.norm(x)
  33. x = self.pwconv1(x)
  34. x = self.act(x)
  35. x = self.pwconv2(x)
  36. if self.gamma is not None:
  37. x = self.gamma * x
  38. x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
  39. x = input + self.drop_path(x)
  40. return x
  41. class ConvNeXt(nn.Module):
  42. r""" ConvNeXt
  43. A PyTorch impl of : `A ConvNet for the 2020s` -
  44. https://arxiv.org/pdf/2201.03545.pdf
  45. Args:
  46. in_chans (int): Number of input image channels. Default: 3
  47. num_classes (int): Number of classes for classification head. Default: 1000
  48. depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
  49. dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
  50. drop_path_rate (float): Stochastic depth rate. Default: 0.
  51. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
  52. head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
  53. """
  54. def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
  55. drop_path_rate=0.4, layer_scale_init_value=1.0, out_indices=[0, 1, 2, 3], **kwargs
  56. ):
  57. super().__init__()
  58. self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
  59. stem = nn.Sequential(
  60. nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
  61. LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
  62. )
  63. self.downsample_layers.append(stem)
  64. for i in range(3):
  65. downsample_layer = nn.Sequential(
  66. LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
  67. nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
  68. )
  69. self.downsample_layers.append(downsample_layer)
  70. self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
  71. self.pretrained = kwargs.get('pretrained', True)
  72. dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
  73. self.out_channels = [96, 192, 384, 768]
  74. cur = 0
  75. for i in range(4):
  76. stage = nn.Sequential(
  77. *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
  78. layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
  79. )
  80. self.stages.append(stage)
  81. cur += depths[i]
  82. self.out_indices = out_indices
  83. norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
  84. for i_layer in range(4):
  85. layer = norm_layer(dims[i_layer])
  86. layer_name = f'norm{i_layer}'
  87. self.add_module(layer_name, layer)
  88. if self.pretrained:
  89. ckpt_path = f'./weights/convnext_tiny_1k_512x512.pth'
  90. logger = logging.getLogger('torchocr')
  91. if os.path.exists(ckpt_path):
  92. logger.info('load convnext weights')
  93. self.load_state_dict(torch.load(ckpt_path), strict=True)
  94. else:
  95. logger.info(f'{ckpt_path} not exists')
  96. self.apply(self._init_weights)
  97. else:
  98. self.apply(self._init_weights)
  99. def _init_weights(self, m):
  100. if isinstance(m, nn.Linear):
  101. nn.init.trunc_normal_(m.weight, std=.02)
  102. if isinstance(m, nn.Linear) and m.bias is not None:
  103. nn.init.constant_(m.bias, 0)
  104. elif isinstance(m, nn.LayerNorm):
  105. nn.init.constant_(m.bias, 0)
  106. nn.init.constant_(m.weight, 1.0)
  107. def forward_features(self, x):
  108. outs = []
  109. for i in range(4):
  110. x = self.downsample_layers[i](x)
  111. x = self.stages[i](x)
  112. if i in self.out_indices:
  113. norm_layer = getattr(self, f'norm{i}')
  114. x_out = norm_layer(x)
  115. outs.append(x_out)
  116. return tuple(outs)
  117. def forward(self, x):
  118. x = self.forward_features(x)
  119. return x
  120. class LayerNorm(nn.Module):
  121. r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
  122. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
  123. shape (batch_size, height, width, channels) while channels_first corresponds to inputs
  124. with shape (batch_size, channels, height, width).
  125. """
  126. def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
  127. super().__init__()
  128. self.weight = nn.Parameter(torch.ones(normalized_shape))
  129. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  130. self.eps = eps
  131. self.data_format = data_format
  132. if self.data_format not in ["channels_last", "channels_first"]:
  133. raise NotImplementedError
  134. self.normalized_shape = (normalized_shape,)
  135. def forward(self, x):
  136. if self.data_format == "channels_last":
  137. return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  138. elif self.data_format == "channels_first":
  139. u = x.mean(1, keepdim=True)
  140. s = (x - u).pow(2).mean(1, keepdim=True)
  141. x = (x - u) / torch.sqrt(s + self.eps)
  142. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  143. return x