CreateRecAug.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. import cv2
  2. import abc
  3. import random
  4. import numpy as np
  5. from PIL import Image, ImageEnhance, ImageFilter, ImageOps, ImageDraw
  6. def cv2pil(image):
  7. """
  8. 将bgr格式的numpy的图像转换为pil
  9. :param image: 图像数组
  10. :return: Image对象
  11. """
  12. assert isinstance(image, np.ndarray), 'input image type is not cv2'
  13. if len(image.shape) == 2:
  14. return Image.fromarray(image)
  15. elif len(image.shape) == 3:
  16. return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
  17. def get_pil_image(image):
  18. """
  19. 将图像统一转换为PIL格式
  20. :param image: 图像
  21. :return: Image格式的图像
  22. """
  23. if isinstance(image, Image.Image): # or isinstance(image, PIL.JpegImagePlugin.JpegImageFile):
  24. return image
  25. elif isinstance(image, np.ndarray):
  26. return cv2pil(image)
  27. def get_cv_image(image):
  28. """
  29. 将图像转换为numpy格式的数据
  30. :param image: 图像
  31. :return: ndarray格式的图像数据
  32. """
  33. if isinstance(image, np.ndarray):
  34. return image
  35. elif isinstance(image, Image.Image): # or isinstance(image, PIL.JpegImagePlugin.JpegImageFile):
  36. return pil2cv(image)
  37. def pil2cv(image):
  38. """
  39. 将Image对象转换为ndarray格式图像
  40. :param image: 图像对象
  41. :return: ndarray图像数组
  42. """
  43. if len(image.split()) == 1:
  44. return np.asarray(image)
  45. elif len(image.split()) == 3:
  46. return cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  47. elif len(image.split()) == 4:
  48. return cv2.cvtColor(np.asarray(image), cv2.COLOR_RGBA2BGR)
  49. class TransBase(object):
  50. """
  51. 数据增广的基类
  52. """
  53. def __init__(self, probability=1.):
  54. """
  55. 初始化对象
  56. :param probability: 执行概率
  57. """
  58. super(TransBase, self).__init__()
  59. self.probability = probability
  60. @abc.abstractmethod
  61. def trans_function(self, _image):
  62. """
  63. 初始化执行函数,需要进行重载
  64. :param _image: 待处理图像
  65. :return: 执行后的Image对象
  66. """
  67. pass
  68. # @utils.zlog
  69. def process(self, _image):
  70. """
  71. 调用执行函数
  72. :param _image: 待处理图像
  73. :return: 执行后的Image对象
  74. """
  75. if np.random.random() < self.probability:
  76. return self.trans_function(_image)
  77. else:
  78. return _image
  79. def __call__(self, _image):
  80. """
  81. 重载(),方便直接进行调用
  82. :param _image: 待处理图像
  83. :return: 执行后的Image
  84. """
  85. return self.process(_image)
  86. class RandomContrast(TransBase):
  87. """
  88. 随机对比度
  89. """
  90. def setparam(self, lower=0.5, upper=1.5):
  91. self.lower = lower
  92. self.upper = upper
  93. assert self.upper >= self.lower, "upper must be >= lower."
  94. assert self.lower >= 0, "lower must be non-negative."
  95. def trans_function(self, _image):
  96. _image = get_pil_image(_image)
  97. contrast_enhance = ImageEnhance.Contrast(_image)
  98. return contrast_enhance.enhance(random.uniform(self.lower, self.upper))
  99. class RandomLine(TransBase):
  100. """
  101. 在图像增加一条简单的随机线
  102. """
  103. def trans_function(self, image):
  104. image = get_pil_image(image)
  105. draw = ImageDraw.Draw(image)
  106. h = image.height
  107. w = image.width
  108. y0 = random.randint(h // 4, h * 3 // 4)
  109. y1 = np.clip(random.randint(-3, 3) + y0, 0, h - 1)
  110. color = random.randint(0, 30)
  111. draw.line(((0, y0), (w - 1, y1)), fill=(color, color, color), width=2)
  112. return image
  113. class RandomBrightness(TransBase):
  114. """
  115. 随机对比度
  116. """
  117. def setparam(self, lower=0.5, upper=1.5):
  118. self.lower = lower
  119. self.upper = upper
  120. assert self.upper >= self.lower, "upper must be >= lower."
  121. assert self.lower >= 0, "lower must be non-negative."
  122. def trans_function(self, image):
  123. image = get_pil_image(image)
  124. bri = ImageEnhance.Brightness(image)
  125. return bri.enhance(random.uniform(self.lower, self.upper))
  126. class RandomColor(TransBase):
  127. """
  128. 随机色彩平衡
  129. """
  130. def setparam(self, lower=0.5, upper=1.5):
  131. self.lower = lower
  132. self.upper = upper
  133. assert self.upper >= self.lower, "upper must be >= lower."
  134. assert self.lower >= 0, "lower must be non-negative."
  135. def trans_function(self, image):
  136. image = get_pil_image(image)
  137. col = ImageEnhance.Color(image)
  138. return col.enhance(random.uniform(self.lower, self.upper))
  139. class RandomSharpness(TransBase):
  140. """
  141. 随机锐度
  142. """
  143. def setparam(self, lower=0.5, upper=1.5):
  144. self.lower = lower
  145. self.upper = upper
  146. assert self.upper >= self.lower, "upper must be >= lower."
  147. assert self.lower >= 0, "lower must be non-negative."
  148. def trans_function(self, image):
  149. image = get_pil_image(image)
  150. sha = ImageEnhance.Sharpness(image)
  151. return sha.enhance(random.uniform(self.lower, self.upper))
  152. class Compress(TransBase):
  153. """
  154. 随机压缩率,利用jpeg的有损压缩来增广
  155. """
  156. def setparam(self, lower=5, upper=85):
  157. self.lower = lower
  158. self.upper = upper
  159. assert self.upper >= self.lower, "upper must be >= lower."
  160. assert self.lower >= 0, "lower must be non-negative."
  161. def trans_function(self, image):
  162. img = get_cv_image(image)
  163. param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]
  164. img_encode = cv2.imencode('.jpeg', img, param)
  165. img_decode = cv2.imdecode(img_encode[1], cv2.IMREAD_COLOR)
  166. pil_img = cv2pil(img_decode)
  167. if len(image.split()) == 1:
  168. pil_img = pil_img.convert('L')
  169. return pil_img
  170. class Exposure(TransBase):
  171. """
  172. 随机区域曝光
  173. """
  174. def setparam(self, lower=5, upper=10):
  175. self.lower = lower
  176. self.upper = upper
  177. assert self.upper >= self.lower, "upper must be >= lower."
  178. assert self.lower >= 0, "lower must be non-negative."
  179. def trans_function(self, image):
  180. image = get_cv_image(image)
  181. h, w = image.shape[:2]
  182. x0 = random.randint(0, w)
  183. y0 = random.randint(0, h)
  184. x1 = random.randint(x0, w)
  185. y1 = random.randint(y0, h)
  186. transparent_area = (x0, y0, x1, y1)
  187. mask = Image.new('L', (w, h), color=255)
  188. draw = ImageDraw.Draw(mask)
  189. mask = np.array(mask)
  190. if len(image.shape) == 3:
  191. mask = mask[:, :, np.newaxis]
  192. mask = np.concatenate([mask, mask, mask], axis=2)
  193. draw.rectangle(transparent_area, fill=random.randint(150, 255))
  194. reflection_result = image + (255 - mask)
  195. reflection_result = np.clip(reflection_result, 0, 255)
  196. return cv2pil(reflection_result)
  197. class Rotate(TransBase):
  198. """
  199. 随机旋转
  200. """
  201. def setparam(self, lower=-5, upper=5):
  202. self.lower = lower
  203. self.upper = upper
  204. assert self.upper >= self.lower, "upper must be >= lower."
  205. def trans_function(self, image):
  206. image = get_pil_image(image)
  207. rot = random.uniform(self.lower, self.upper)
  208. trans_img = image.rotate(rot, expand=True)
  209. return trans_img
  210. class Blur(TransBase):
  211. """
  212. 随机高斯模糊
  213. """
  214. def setparam(self, lower=0, upper=1):
  215. self.lower = lower
  216. self.upper = upper
  217. assert self.upper >= self.lower, "upper must be >= lower."
  218. assert self.lower >= 0, "lower must be non-negative."
  219. def trans_function(self, image):
  220. image = get_pil_image(image)
  221. image = image.filter(ImageFilter.GaussianBlur(radius=1.5))
  222. return image
  223. class MotionBlur(TransBase):
  224. """
  225. 随机运动模糊
  226. """
  227. def setparam(self, degree=5, angle=180):
  228. self.degree = degree
  229. self.angle = angle
  230. def trans_function(self, image):
  231. image = get_pil_image(image)
  232. angle = random.randint(0, self.angle)
  233. M = cv2.getRotationMatrix2D((self.degree / 2, self.degree / 2), angle, 1)
  234. motion_blur_kernel = np.diag(np.ones(self.degree))
  235. motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree))
  236. motion_blur_kernel = motion_blur_kernel / self.degree
  237. image = image.filter(ImageFilter.Kernel(size=(self.degree, self.degree), kernel=motion_blur_kernel.reshape(-1)))
  238. return image
  239. class Salt(TransBase):
  240. """
  241. 随机椒盐噪音
  242. """
  243. def setparam(self, rate=0.02):
  244. self.rate = rate
  245. def trans_function(self, image):
  246. image = get_pil_image(image)
  247. num_noise = int(image.size[1] * image.size[0] * self.rate)
  248. # assert len(image.split()) == 1
  249. for k in range(num_noise):
  250. i = int(np.random.random() * image.size[1])
  251. j = int(np.random.random() * image.size[0])
  252. image.putpixel((j, i), int(np.random.random() * 255))
  253. return image
  254. class AdjustResolution(TransBase):
  255. """
  256. 随机分辨率
  257. """
  258. def setparam(self, max_rate=0.95, min_rate=0.5):
  259. self.max_rate = max_rate
  260. self.min_rate = min_rate
  261. def trans_function(self, image):
  262. image = get_pil_image(image)
  263. w, h = image.size
  264. rate = np.random.random() * (self.max_rate - self.min_rate) + self.min_rate
  265. w2 = int(w * rate)
  266. h2 = int(h * rate)
  267. image = image.resize((w2, h2))
  268. image = image.resize((w, h))
  269. return image
  270. class Crop(TransBase):
  271. """
  272. 随机抠图,并且抠图区域透视变换为原图大小
  273. """
  274. def setparam(self, maxv=2):
  275. self.maxv = maxv
  276. def trans_function(self, image):
  277. img = get_cv_image(image)
  278. h, w = img.shape[:2]
  279. org = np.array([[0, np.random.randint(0, self.maxv)],
  280. [w, np.random.randint(0, self.maxv)],
  281. [0, h - np.random.randint(0, self.maxv)],
  282. [w, h - np.random.randint(0, self.maxv)]], np.float32)
  283. dst = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
  284. M = cv2.getPerspectiveTransform(org, dst)
  285. res = cv2.warpPerspective(img, M, (w, h))
  286. return get_pil_image(res)
  287. class Crop2(TransBase):
  288. """
  289. 随机抠图,并且抠图区域透视变换为原图大小
  290. """
  291. def setparam(self, maxv_h=4, maxv_w=4):
  292. self.maxv_h = maxv_h
  293. self.maxv_w = maxv_w
  294. def trans_function(self, image_and_loc):
  295. image, left, top, right, bottom = image_and_loc
  296. w, h = image.size
  297. left = np.clip(left, 0, w - 1)
  298. right = np.clip(right, 0, w - 1)
  299. top = np.clip(top, 0, h - 1)
  300. bottom = np.clip(bottom, 0, h - 1)
  301. img = get_cv_image(image)
  302. try:
  303. res = get_pil_image(img[top:bottom, left:right])
  304. return res
  305. except AttributeError as e:
  306. print('error')
  307. image.save('test_imgs/t.png')
  308. print(left, top, right, bottom)
  309. h = bottom - top
  310. w = right - left
  311. org = np.array(
  312. [[left - np.random.randint(0, self.maxv_w), top + np.random.randint(-self.maxv_h, self.maxv_h // 2)],
  313. [right + np.random.randint(0, self.maxv_w), top + np.random.randint(-self.maxv_h, self.maxv_h // 2)],
  314. [left - np.random.randint(0, self.maxv_w), bottom - np.random.randint(-self.maxv_h, self.maxv_h // 2)],
  315. [right + np.random.randint(0, self.maxv_w), bottom - np.random.randint(-self.maxv_h, self.maxv_h // 2)]],
  316. np.float32)
  317. dst = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
  318. M = cv2.getPerspectiveTransform(org, dst)
  319. res = cv2.warpPerspective(img, M, (w, h))
  320. return get_pil_image(res)
  321. class Stretch(TransBase):
  322. """
  323. 随机图像横向拉伸
  324. """
  325. def setparam(self, max_rate=1.2, min_rate=0.8):
  326. self.max_rate = max_rate
  327. self.min_rate = min_rate
  328. def trans_function(self, image):
  329. image = get_pil_image(image)
  330. w, h = image.size
  331. rate = np.random.random() * (self.max_rate - self.min_rate) + self.min_rate
  332. w2 = int(w * rate)
  333. image = image.resize((w2, h))
  334. return image
  335. class DataAug:
  336. def __init__(self):
  337. self.crop = Crop(probability=0.1)
  338. self.crop2 = Crop2(probability=1.1)
  339. self.random_contrast = RandomContrast(probability=0.1)
  340. self.random_brightness = RandomBrightness(probability=0.1)
  341. self.random_color = RandomColor(probability=0.1)
  342. self.random_sharpness = RandomSharpness(probability=0.1)
  343. self.compress = Compress(probability=0.3)
  344. self.exposure = Exposure(probability=0.1)
  345. self.rotate = Rotate(probability=0.1)
  346. self.blur = Blur(probability=0.3)
  347. self.motion_blur = MotionBlur(probability=0.3)
  348. self.salt = Salt(probability=0.1)
  349. self.adjust_resolution = AdjustResolution(probability=0.1)
  350. self.stretch = Stretch(probability=0.1)
  351. self.random_line = RandomLine(probability=0.3)
  352. self.crop.setparam()
  353. self.crop2.setparam()
  354. self.random_contrast.setparam()
  355. self.random_brightness.setparam()
  356. self.random_color.setparam()
  357. self.random_sharpness.setparam()
  358. self.compress.setparam()
  359. self.exposure.setparam()
  360. self.rotate.setparam()
  361. self.blur.setparam()
  362. self.motion_blur.setparam()
  363. self.salt.setparam()
  364. self.adjust_resolution.setparam()
  365. self.stretch.setparam()
  366. def aug_img(self, img):
  367. img = self.crop.process(img)
  368. img = self.random_contrast.process(img)
  369. img = self.random_brightness.process(img)
  370. img = self.random_color.process(img)
  371. img = self.random_sharpness.process(img)
  372. img = self.random_line.process(img)
  373. if img.size[1] >= 32:
  374. img = self.compress.process(img)
  375. img = self.adjust_resolution.process(img)
  376. img = self.motion_blur.process(img)
  377. img = self.blur.process(img)
  378. img = self.exposure.process(img)
  379. img = self.rotate.process(img)
  380. img = self.salt.process(img)
  381. img = self.inverse_color(img)
  382. img = self.stretch.process(img)
  383. return img
  384. def inverse_color(self, image):
  385. if np.random.random() < 0.4:
  386. image = ImageOps.invert(image)
  387. return image
  388. if __name__ == '__main__':
  389. img = Image.open('./test_image.png')
  390. data_augment = DataAug()
  391. augmented_img = data_augment.aug_img(img)
  392. augmented_img.show()