TPS.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from collections import OrderedDict
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. class TPS(nn.Module):
  7. """ Rectification Network of RARE, namely TPS based STN """
  8. def __init__(self, F, I_size, I_r_size, I_channel_num=1):
  9. """ Based on RARE TPS
  10. input:
  11. batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
  12. I_size : (height, width) of the input image I
  13. I_r_size : (height, width) of the rectified image I_r
  14. I_channel_num : the number of channels of the input image I
  15. output:
  16. batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
  17. """
  18. super(TPS, self).__init__()
  19. self.F = F
  20. self.I_size = I_size
  21. self.I_r_size = I_r_size # = (I_r_height, I_r_width)
  22. self.I_channel_num = I_channel_num
  23. self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
  24. self.GridGenerator = GridGenerator(self.F, self.I_r_size)
  25. def forward(self, batch_I):
  26. batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
  27. build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
  28. build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
  29. if torch.__version__ > "1.2.0":
  30. batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
  31. else:
  32. batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
  33. return batch_I_r
  34. class LocalizationNetwork(nn.Module):
  35. """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
  36. def __init__(self, F, I_channel_num):
  37. super(LocalizationNetwork, self).__init__()
  38. self.F = F
  39. self.I_channel_num = I_channel_num
  40. self.conv = nn.Sequential(
  41. nn.Conv2d(in_channels=self.I_channel_num, out_channels=64,
  42. kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
  43. nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
  44. nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
  45. nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
  46. nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
  47. nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
  48. nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
  49. nn.AdaptiveAvgPool2d(1) # batch_size x 512
  50. )
  51. self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
  52. self.localization_fc2 = nn.Linear(256, self.F * 2)
  53. # Init fc2 in LocalizationNetwork
  54. self.localization_fc2.weight.data.fill_(0)
  55. """ see RARE paper Fig. 6 (a) """
  56. ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
  57. ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
  58. ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
  59. ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
  60. ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
  61. initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
  62. self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)
  63. def forward(self, batch_I):
  64. """
  65. input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
  66. output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
  67. """
  68. batch_size = batch_I.size(0)
  69. features = self.conv(batch_I).view(batch_size, -1)
  70. batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2)
  71. return batch_C_prime
  72. class GridGenerator(nn.Module):
  73. """ Grid Generator of RARE, which produces P_prime by multipling T with P """
  74. def __init__(self, F, I_r_size):
  75. """ Generate P_hat and inv_delta_C for later """
  76. super(GridGenerator, self).__init__()
  77. self.eps = 1e-6
  78. self.I_r_height, self.I_r_width = I_r_size
  79. self.F = F
  80. self.C = self._build_C(self.F) # F x 2
  81. self.P = self._build_P(self.I_r_width, self.I_r_height)
  82. # for multi-gpu, you need register buffer
  83. # self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
  84. # self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
  85. # for fine-tuning with different image width, you may use below instead of self.register_buffer
  86. self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3
  87. self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3
  88. def _build_C(self, F):
  89. """ Return coordinates of fiducial points in I_r; C """
  90. ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
  91. ctrl_pts_y_top = -1 * np.ones(int(F / 2))
  92. ctrl_pts_y_bottom = np.ones(int(F / 2))
  93. ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
  94. ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
  95. C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
  96. return C # F x 2
  97. def _build_inv_delta_C(self, F, C):
  98. """ Return inv_delta_C which is needed to calculate T """
  99. hat_C = np.zeros((F, F), dtype=float) # F x F
  100. for i in range(0, F):
  101. for j in range(i, F):
  102. r = np.linalg.norm(C[i] - C[j])
  103. hat_C[i, j] = r
  104. hat_C[j, i] = r
  105. np.fill_diagonal(hat_C, 1)
  106. hat_C = (hat_C ** 2) * np.log(hat_C)
  107. # print(C.shape, hat_C.shape)
  108. delta_C = np.concatenate( # F+3 x F+3
  109. [
  110. np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
  111. np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3
  112. np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3
  113. ],
  114. axis=0
  115. )
  116. inv_delta_C = np.linalg.inv(delta_C)
  117. return inv_delta_C # F+3 x F+3
  118. def _build_P(self, I_r_width, I_r_height):
  119. I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width
  120. I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height
  121. P = np.stack( # self.I_r_width x self.I_r_height x 2
  122. np.meshgrid(I_r_grid_x, I_r_grid_y),
  123. axis=2
  124. )
  125. return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
  126. def _build_P_hat(self, F, C, P):
  127. n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
  128. P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2
  129. C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
  130. P_diff = P_tile - C_tile # n x F x 2
  131. rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F
  132. rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F
  133. P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
  134. return P_hat # n x F+3
  135. def build_P_prime(self, batch_C_prime):
  136. """ Generate Grid from batch_C_prime [batch_size x F x 2] """
  137. device = batch_C_prime.device
  138. batch_size = batch_C_prime.size(0)
  139. batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1).to(device)
  140. batch_P_hat = self.P_hat.repeat(batch_size, 1, 1).to(device)
  141. batch_C_prime_with_zeros = torch.cat(
  142. (
  143. batch_C_prime,
  144. torch.zeros(batch_size, 3, 2).float().to(device)
  145. ), dim=1
  146. ) # batch_size x F+3 x 2
  147. batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
  148. batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
  149. return batch_P_prime # batch_size x n x 2