RecCollateFn.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/16 17:06
  3. # @Author : zhoujun
  4. import torch
  5. import numpy as np
  6. import cv2
  7. from torchvision import transforms
  8. class Resize:
  9. def __init__(self, img_h, img_w, pad=True, **kwargs):
  10. self.img_h = img_h
  11. self.img_w = img_w
  12. self.pad = pad
  13. def __call__(self, img: np.ndarray):
  14. """
  15. 对图片进行处理,先按照高度进行resize,resize之后如果宽度不足指定宽度,就补黑色像素,否则就强行缩放到指定宽度
  16. :param img_path: 图片地址
  17. :return: 处理为指定宽高的图片
  18. """
  19. img_h = self.img_h
  20. img_w = self.img_w
  21. h, w = img.shape[:2]
  22. ratio_h = self.img_h / h
  23. new_w = int(w * ratio_h)
  24. if new_w < img_w and self.pad:
  25. img = cv2.resize(img, (new_w, img_h))
  26. if len(img.shape) == 2:
  27. img = np.expand_dims(img, 2)
  28. step = np.zeros((img_h, img_w - new_w, img.shape[-1]), dtype=img.dtype)
  29. img = np.column_stack((img, step))
  30. else:
  31. img = cv2.resize(img, (img_w, img_h))
  32. if len(img.shape) == 2:
  33. img = np.expand_dims(img, 2)
  34. if img.shape[-1] == 1:
  35. img = img[:, :, 0]
  36. return img
  37. class RecCollateFn:
  38. def __init__(self, *args, **kwargs):
  39. self.process = kwargs['dataset'].process
  40. self.t = transforms.ToTensor()
  41. def __call__(self, batch):
  42. resize_images = []
  43. all_same_height_images = [self.process.resize_with_specific_height(_['img']) for _ in batch]
  44. max_img_w = max({m_img.shape[1] for m_img in all_same_height_images})
  45. # make sure max_img_w is integral multiple of 8
  46. max_img_w = int(np.ceil(max_img_w / 8) * 8)
  47. labels = []
  48. for i in range(len(batch)):
  49. _label = batch[i]['label']
  50. labels.append(_label)
  51. img = self.process.normalize_img(all_same_height_images[i])
  52. img = self.process.width_pad_img(img, max_img_w)
  53. img = img.transpose([2, 0, 1])
  54. resize_images.append(torch.tensor(img, dtype=torch.float))
  55. resize_images = torch.stack(resize_images)
  56. return {'img': resize_images, 'label': labels}
  57. class RecCollateFnWithResize:
  58. """
  59. 将图片resize到固定宽度的RecCollateFn
  60. """
  61. def __init__(self, *args, **kwargs):
  62. from torchvision import transforms
  63. self.img_h = kwargs.get('img_h', 32)
  64. self.img_w = kwargs.get('img_w', 320)
  65. self.pad = kwargs.get('pad', True)
  66. self.t = transforms.ToTensor()
  67. def __call__(self, batch):
  68. resize_images = []
  69. resize_image_class = Resize(self.img_h, self.img_w, self.pad)
  70. labels = []
  71. for data in batch:
  72. labels.append(data['label'])
  73. resize_image = resize_image_class(data['img'])
  74. resize_image = self.t(resize_image)
  75. resize_images.append(resize_image)
  76. resize_images = torch.cat([t.unsqueeze(0) for t in resize_images], 0)
  77. return {'img':resize_images,'label':labels}