warp_mls.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. class WarpMLS:
  16. def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
  17. self.src = src
  18. self.src_pts = src_pts
  19. self.dst_pts = dst_pts
  20. self.pt_count = len(self.dst_pts)
  21. self.dst_w = dst_w
  22. self.dst_h = dst_h
  23. self.trans_ratio = trans_ratio
  24. self.grid_size = 100
  25. self.rdx = np.zeros((self.dst_h, self.dst_w))
  26. self.rdy = np.zeros((self.dst_h, self.dst_w))
  27. @staticmethod
  28. def __bilinear_interp(x, y, v11, v12, v21, v22):
  29. return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
  30. (1 - y) + v22 * y) * x
  31. def generate(self):
  32. self.calc_delta()
  33. return self.gen_img()
  34. def calc_delta(self):
  35. w = np.zeros(self.pt_count, dtype=np.float32)
  36. if self.pt_count < 2:
  37. return
  38. i = 0
  39. while 1:
  40. if self.dst_w <= i < self.dst_w + self.grid_size - 1:
  41. i = self.dst_w - 1
  42. elif i >= self.dst_w:
  43. break
  44. j = 0
  45. while 1:
  46. if self.dst_h <= j < self.dst_h + self.grid_size - 1:
  47. j = self.dst_h - 1
  48. elif j >= self.dst_h:
  49. break
  50. sw = 0
  51. swp = np.zeros(2, dtype=np.float32)
  52. swq = np.zeros(2, dtype=np.float32)
  53. new_pt = np.zeros(2, dtype=np.float32)
  54. cur_pt = np.array([i, j], dtype=np.float32)
  55. k = 0
  56. for k in range(self.pt_count):
  57. if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
  58. break
  59. w[k] = 1. / (
  60. (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
  61. (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
  62. sw += w[k]
  63. swp = swp + w[k] * np.array(self.dst_pts[k])
  64. swq = swq + w[k] * np.array(self.src_pts[k])
  65. if k == self.pt_count - 1:
  66. pstar = 1 / sw * swp
  67. qstar = 1 / sw * swq
  68. miu_s = 0
  69. for k in range(self.pt_count):
  70. if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
  71. continue
  72. pt_i = self.dst_pts[k] - pstar
  73. miu_s += w[k] * np.sum(pt_i * pt_i)
  74. cur_pt -= pstar
  75. cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
  76. for k in range(self.pt_count):
  77. if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
  78. continue
  79. pt_i = self.dst_pts[k] - pstar
  80. pt_j = np.array([-pt_i[1], pt_i[0]])
  81. tmp_pt = np.zeros(2, dtype=np.float32)
  82. tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
  83. np.sum(pt_j * cur_pt) * self.src_pts[k][1]
  84. tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
  85. np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
  86. tmp_pt *= (w[k] / miu_s)
  87. new_pt += tmp_pt
  88. new_pt += qstar
  89. else:
  90. new_pt = self.src_pts[k]
  91. self.rdx[j, i] = new_pt[0] - i
  92. self.rdy[j, i] = new_pt[1] - j
  93. j += self.grid_size
  94. i += self.grid_size
  95. def gen_img(self):
  96. src_h, src_w = self.src.shape[:2]
  97. dst = np.zeros_like(self.src, dtype=np.float32)
  98. for i in np.arange(0, self.dst_h, self.grid_size):
  99. for j in np.arange(0, self.dst_w, self.grid_size):
  100. ni = i + self.grid_size
  101. nj = j + self.grid_size
  102. w = h = self.grid_size
  103. if ni >= self.dst_h:
  104. ni = self.dst_h - 1
  105. h = ni - i + 1
  106. if nj >= self.dst_w:
  107. nj = self.dst_w - 1
  108. w = nj - j + 1
  109. di = np.reshape(np.arange(h), (-1, 1))
  110. dj = np.reshape(np.arange(w), (1, -1))
  111. delta_x = self.__bilinear_interp(
  112. di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
  113. self.rdx[ni, j], self.rdx[ni, nj])
  114. delta_y = self.__bilinear_interp(
  115. di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
  116. self.rdy[ni, j], self.rdy[ni, nj])
  117. nx = j + dj + delta_x * self.trans_ratio
  118. ny = i + di + delta_y * self.trans_ratio
  119. nx = np.clip(nx, 0, src_w - 1)
  120. ny = np.clip(ny, 0, src_h - 1)
  121. nxi = np.array(np.floor(nx), dtype=np.int32)
  122. nyi = np.array(np.floor(ny), dtype=np.int32)
  123. nxi1 = np.array(np.ceil(nx), dtype=np.int32)
  124. nyi1 = np.array(np.ceil(ny), dtype=np.int32)
  125. if len(self.src.shape) == 3:
  126. x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
  127. y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
  128. else:
  129. x = ny - nyi
  130. y = nx - nxi
  131. dst[i:i + h, j:j + w] = self.__bilinear_interp(
  132. x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
  133. self.src[nyi1, nxi], self.src[nyi1, nxi1])
  134. dst = np.clip(dst, 0, 255)
  135. dst = np.array(dst, dtype=np.uint8)
  136. return dst