augment.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/8/23 21:52
  3. # @Author : zhoujun
  4. import math
  5. import numbers
  6. import random
  7. import cv2
  8. import numpy as np
  9. from skimage.util import random_noise
  10. __all__ = ['RandomNoise', 'RandomResize', 'RandomScale', 'ResizeShortSize', 'RandomRotateImgBox', 'HorizontalFlip',
  11. 'VerticallFlip', 'ResizeFixedSize', 'ResizeLongSize']
  12. class RandomNoise:
  13. def __init__(self, random_rate):
  14. self.random_rate = random_rate
  15. def __call__(self, data: dict):
  16. """
  17. 对图片加噪声
  18. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  19. :return:
  20. """
  21. if random.random() > self.random_rate:
  22. return data
  23. data['img'] = (random_noise(data['img'], mode='gaussian', clip=True) * 255).astype(data['img'].dtype)
  24. return data
  25. class RandomScale:
  26. def __init__(self, scales, random_rate):
  27. """
  28. :param scales: 尺度
  29. :param ramdon_rate: 随机系数
  30. :return:
  31. """
  32. self.random_rate = random_rate
  33. self.scales = scales
  34. def __call__(self, data: dict) -> dict:
  35. """
  36. 从scales中随机选择一个尺度,对图片和文本框进行缩放
  37. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  38. :return:
  39. """
  40. if random.random() > self.random_rate:
  41. return data
  42. im = data['img']
  43. text_polys = data['text_polys']
  44. tmp_text_polys = text_polys.copy()
  45. rd_scale = float(np.random.choice(self.scales))
  46. im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
  47. tmp_text_polys *= rd_scale
  48. data['img'] = im
  49. data['text_polys'] = tmp_text_polys
  50. return data
  51. class RandomRotateImgBox:
  52. def __init__(self, degrees, random_rate, same_size=False):
  53. """
  54. :param degrees: 角度,可以是一个数值或者list
  55. :param ramdon_rate: 随机系数
  56. :param same_size: 是否保持和原图一样大
  57. :return:
  58. """
  59. if isinstance(degrees, numbers.Number):
  60. if degrees < 0:
  61. raise ValueError("If degrees is a single number, it must be positive.")
  62. degrees = (-degrees, degrees)
  63. elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray):
  64. if len(degrees) != 2:
  65. raise ValueError("If degrees is a sequence, it must be of len 2.")
  66. degrees = degrees
  67. else:
  68. raise Exception('degrees must in Number or list or tuple or np.ndarray')
  69. self.degrees = degrees
  70. self.same_size = same_size
  71. self.random_rate = random_rate
  72. def __call__(self, data: dict) -> dict:
  73. """
  74. 从scales中随机选择一个尺度,对图片和文本框进行缩放
  75. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  76. :return:
  77. """
  78. if random.random() > self.random_rate:
  79. return data
  80. im = data['img']
  81. text_polys = data['text_polys']
  82. # ---------------------- 旋转图像 ----------------------
  83. w = im.shape[1]
  84. h = im.shape[0]
  85. angle = np.random.uniform(self.degrees[0], self.degrees[1])
  86. if self.same_size:
  87. nw = w
  88. nh = h
  89. else:
  90. # 角度变弧度
  91. rangle = np.deg2rad(angle)
  92. # 计算旋转之后图像的w, h
  93. nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
  94. nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
  95. # 构造仿射矩阵
  96. rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
  97. # 计算原图中心点到新图中心点的偏移量
  98. rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
  99. # 更新仿射矩阵
  100. rot_mat[0, 2] += rot_move[0]
  101. rot_mat[1, 2] += rot_move[1]
  102. # 仿射变换
  103. rot_img = cv2.warpAffine(im, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
  104. # ---------------------- 矫正bbox坐标 ----------------------
  105. # rot_mat是最终的旋转矩阵
  106. # 获取原始bbox的四个中点,然后将这四个点转换到旋转后的坐标系下
  107. rot_text_polys = list()
  108. for bbox in text_polys:
  109. point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
  110. point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
  111. point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
  112. point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
  113. rot_text_polys.append([point1, point2, point3, point4])
  114. data['img'] = rot_img
  115. data['text_polys'] = np.array(rot_text_polys)
  116. return data
  117. class RandomResize:
  118. def __init__(self, size, random_rate, keep_ratio=False):
  119. """
  120. :param input_size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h]
  121. :param ramdon_rate: 随机系数
  122. :param keep_ratio: 是否保持长宽比
  123. :return:
  124. """
  125. if isinstance(size, numbers.Number):
  126. if size < 0:
  127. raise ValueError("If input_size is a single number, it must be positive.")
  128. size = (size, size)
  129. elif isinstance(size, list) or isinstance(size, tuple) or isinstance(size, np.ndarray):
  130. if len(size) != 2:
  131. raise ValueError("If input_size is a sequence, it must be of len 2.")
  132. size = (size[0], size[1])
  133. else:
  134. raise Exception('input_size must in Number or list or tuple or np.ndarray')
  135. self.size = size
  136. self.keep_ratio = keep_ratio
  137. self.random_rate = random_rate
  138. def __call__(self, data: dict) -> dict:
  139. """
  140. 从scales中随机选择一个尺度,对图片和文本框进行缩放
  141. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  142. :return:
  143. """
  144. if random.random() > self.random_rate:
  145. return data
  146. im = data['img']
  147. text_polys = data['text_polys']
  148. if self.keep_ratio:
  149. # 将图片短边pad到和长边一样
  150. h, w, c = im.shape
  151. max_h = max(h, self.size[0])
  152. max_w = max(w, self.size[1])
  153. im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8)
  154. im_padded[:h, :w] = im.copy()
  155. im = im_padded
  156. text_polys = text_polys.astype(np.float32)
  157. h, w, _ = im.shape
  158. im = cv2.resize(im, self.size)
  159. w_scale = self.size[0] / float(w)
  160. h_scale = self.size[1] / float(h)
  161. text_polys[:, :, 0] *= w_scale
  162. text_polys[:, :, 1] *= h_scale
  163. data['img'] = im
  164. data['text_polys'] = text_polys
  165. return data
  166. def resize_image(img, short_size):
  167. height, width, _ = img.shape
  168. if height < width:
  169. new_height = short_size
  170. new_width = new_height / height * width
  171. else:
  172. new_width = short_size
  173. new_height = new_width / width * height
  174. new_height = int(round(new_height / 32) * 32)
  175. new_width = int(round(new_width / 32) * 32)
  176. resized_img = cv2.resize(img, (new_width, new_height))
  177. return resized_img, (new_width / width, new_height / height)
  178. class ResizeShortSize:
  179. def __init__(self, short_size, resize_text_polys=True):
  180. """
  181. :param size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h]
  182. :return:
  183. """
  184. self.short_size = short_size
  185. self.resize_text_polys = resize_text_polys
  186. def __call__(self, data: dict) -> dict:
  187. """
  188. 对图片和文本框进行缩放
  189. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  190. :return:
  191. """
  192. im = data['img']
  193. text_polys = data['text_polys']
  194. h, w, _ = im.shape
  195. if min(h, w) < self.short_size:
  196. if h < w:
  197. ratio = float(self.short_size) / h
  198. else:
  199. ratio = float(self.short_size) / w
  200. else:
  201. ratio = 1.
  202. resize_h = int(h * ratio)
  203. resize_w = int(w * ratio)
  204. resize_h = max(int(round(resize_h / 32) * 32), 32)
  205. resize_w = max(int(round(resize_w / 32) * 32), 32)
  206. img = cv2.resize(im, (int(resize_w), int(resize_h)))
  207. if self.resize_text_polys:
  208. text_polys[:, 0] *= ratio
  209. text_polys[:, 1] *= ratio
  210. data['img'] = img
  211. data['text_polys'] = text_polys
  212. return data
  213. class HorizontalFlip:
  214. def __init__(self, random_rate):
  215. """
  216. :param random_rate: 随机系数
  217. """
  218. self.random_rate = random_rate
  219. def __call__(self, data: dict) -> dict:
  220. """
  221. 从scales中随机选择一个尺度,对图片和文本框进行缩放
  222. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  223. :return:
  224. """
  225. if random.random() > self.random_rate:
  226. return data
  227. im = data['img']
  228. text_polys = data['text_polys']
  229. flip_text_polys = text_polys.copy()
  230. flip_im = cv2.flip(im, 1)
  231. h, w, _ = flip_im.shape
  232. flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0]
  233. data['img'] = flip_im
  234. data['text_polys'] = flip_text_polys
  235. return data
  236. class VerticallFlip:
  237. def __init__(self, random_rate):
  238. """
  239. :param random_rate: 随机系数
  240. """
  241. self.random_rate = random_rate
  242. def __call__(self, data: dict) -> dict:
  243. """
  244. 从scales中随机选择一个尺度,对图片和文本框进行缩放
  245. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  246. :return:
  247. """
  248. if random.random() > self.random_rate:
  249. return data
  250. im = data['img']
  251. text_polys = data['text_polys']
  252. flip_text_polys = text_polys.copy()
  253. flip_im = cv2.flip(im, 0)
  254. h, w, _ = flip_im.shape
  255. flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1]
  256. data['img'] = flip_im
  257. data['text_polys'] = flip_text_polys
  258. return data
  259. class ResizeFixedSize:
  260. def __init__(self, short_size, resize_text_polys=True):
  261. """
  262. :param size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h]
  263. :return:
  264. """
  265. self.short_size = short_size
  266. self.resize_text_polys = resize_text_polys
  267. def __call__(self, data: dict) -> dict:
  268. """
  269. 对图片和文本框进行缩放
  270. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  271. :return:
  272. """
  273. im = data['img']
  274. text_polys = data['text_polys']
  275. h, w, _ = im.shape
  276. if min(h, w) < self.short_size:
  277. if h < w:
  278. ratio = float(self.short_size) / h
  279. else:
  280. ratio = float(self.short_size) / w
  281. else:
  282. ratio = 1.
  283. resize_h = int(h * ratio)
  284. resize_w = int(w * ratio)
  285. resize_h = max(int(round(resize_h / 32) * 32), 32)
  286. resize_w = max(int(round(resize_w / 32) * 32), 32)
  287. try:
  288. if int(resize_w) <= 0 or int(resize_h) <= 0:
  289. return None, (None, None)
  290. img = cv2.resize(im, (int(resize_w), int(resize_h)))
  291. except:
  292. print(img.shape, resize_w, resize_h)
  293. import sys
  294. sys.exit(0)
  295. ratio_h = resize_h / float(h)
  296. ratio_w = resize_w / float(w)
  297. if self.resize_text_polys:
  298. text_polys[:, 0] *= ratio_h
  299. text_polys[:, 1] *= ratio_w
  300. data['img'] = img
  301. data['text_polys'] = text_polys
  302. return data
  303. class ResizeLongSize:
  304. def __init__(self, long_size, resize_text_polys=True): # short_size,
  305. """
  306. :param size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h]
  307. :return:
  308. """
  309. # self.short_size = short_size
  310. self.long_size = long_size
  311. self.resize_text_polys = resize_text_polys
  312. def __call__(self, data: dict) -> dict:
  313. """
  314. 对图片和文本框进行缩放
  315. :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
  316. :return:
  317. """
  318. im = data['img']
  319. text_polys = data['text_polys']
  320. h, w, _ = im.shape
  321. if max(h, w) > self.long_size:
  322. if h < w:
  323. ratio = float(self.long_size) / w
  324. else:
  325. ratio = float(self.long_size) / h
  326. else:
  327. ratio = 1.
  328. resize_h = int(h * ratio)
  329. resize_w = int(w * ratio)
  330. resize_h = max(int(round(resize_h / 32) * 32), 32)
  331. resize_w = max(int(round(resize_w / 32) * 32), 32)
  332. img = cv2.resize(im, (int(resize_w), int(resize_h)))
  333. if self.resize_text_polys:
  334. text_polys[:, 0] *= ratio
  335. text_polys[:, 1] *= ratio
  336. data['img'] = img
  337. data['text_polys'] = text_polys
  338. return data