_360cc.py

from __future__ import print_function, absolute_import
import math
import os
import random

import cv2
import numpy as np
import torch.utils.data as data


class _360CC(data.Dataset):
    def __init__(self, config, is_train=True):
        self.root = config.DATASET.ROOT
        self.is_train = is_train

        self.inp_h = config.MODEL.IMAGE_SIZE.H
        self.inp_w = config.MODEL.IMAGE_SIZE.W

        self.dataset_name = config.DATASET.DATASET

        self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
        self.std = np.array(config.DATASET.STD, dtype=np.float32)

        # index -> character mapping, one GBK-encoded character per line
        char_file = config.DATASET.CHAR_FILE
        with open(char_file, 'rb') as file:
            char_dict = {num: char.strip().decode('gbk', 'ignore') for num, char in enumerate(file.readlines())}

        txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']

        # convert name: indices to name: string
        self.labels = []
        with open(txt_file, 'r', encoding='utf-8') as file:
            contents = file.readlines()
            for c in contents:
                imgname = c.split(' ')[0]
                indices = c.split(' ')[1:]
                string = ''.join([char_dict[int(idx)] for idx in indices])
                self.labels.append({imgname: string})

        print("load {} images!".format(self.__len__()))

    def __len__(self):
        return len(self.labels)
    # Previous __getitem__ (kept for reference): stretch-resize the grayscale
    # image directly to (inp_h, inp_w) without preserving the aspect ratio.
    # def __getitem__(self, idx):
    #     img_name = list(self.labels[idx].keys())[0]
    #     img = cv2.imread(os.path.join(self.root, img_name))
    #     img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    #
    #     img_h, img_w = img.shape
    #
    #     img = cv2.resize(img, (0, 0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
    #     img = np.reshape(img, (self.inp_h, self.inp_w, 1))
    #
    #     img = img.astype(np.float32)
    #     img = (img / 255. - self.mean) / self.std
    #     img = img.transpose([2, 0, 1])
    #
    #     return img, idx

    def __getitem__(self, idx):
        img_name = list(self.labels[idx].keys())[0]
        img = cv2.imread(os.path.join(self.root, img_name))

        # data augmentation
        img = img_process(img)
        img_h, img_w = img.shape[:2]

        # true aspect ratio of the image
        ratio = img_w / float(img_h)
        # scale while keeping the aspect ratio
        if math.ceil(self.inp_h * ratio) > self.inp_w:
            # wider than the target width: clamp to inp_w
            resized_w = self.inp_w
        else:
            # otherwise keep the scaled width
            resized_w = int(math.ceil(self.inp_h * ratio))

        img = cv2.resize(img, (resized_w, self.inp_h))
        img = img.astype(np.float32)
        # normalise
        img = (img / 255. - self.mean) / self.std
        img = img.transpose([2, 0, 1])

        # pad the missing width with zeros
        padding_im = np.zeros((3, self.inp_h, self.inp_w), dtype=np.float32)
        # padding_im = np.full((3, self.inp_h, self.inp_w), 255, dtype=np.float32)
        padding_im[:, :, 0:resized_w] = img

        return padding_im, idx
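

# Illustrative usage sketch: the dataset plugs into a standard torch DataLoader.
# `config` is assumed to be the same yacs-style config object read in __init__
# above; this helper is only a sketch and is never called at import time.
def _demo_dataloader(config, batch_size=32):
    dataset = _360CC(config, is_train=True)
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    # images: (batch_size, 3, inp_h, inp_w) float32; each idx maps back to dataset.labels[idx]
    images, indices = next(iter(loader))
    return images, indices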


def resize_norm_img(img):
    """
    Resize and normalise an image for recognition.
    :param img: input image
    """
    # default input size; for English models this is (3, 32, 100)
    imgC = 3
    imgH = 32   # a height of 32 works well in practice
    imgW = 320  # Chinese characters are roughly 1:1; to avoid squashing or stretching
                # after resize and to fit about 10 characters, width = 10 * 32 = 320
    # true height and width of the image
    h, w = img.shape[:2]
    # true aspect ratio
    ratio = w / float(h)
    # scale while keeping the aspect ratio
    if math.ceil(imgH * ratio) > imgW:
        # wider than the default width: clamp to imgW
        resized_w = imgW
    else:
        # otherwise keep the scaled width
        resized_w = int(math.ceil(imgH * ratio))
    resized_image = cv2.resize(img, (resized_w, imgH))
    resized_image = resized_image.astype('float32')
    # normalise to [-1, 1]
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    # pad the missing width with zeros
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    # transpose the padded image back to HWC for visualisation
    draw_img = padding_im.transpose((1, 2, 0))
    return padding_im, draw_img


# add salt noise
def add_noise(img):
    for i in range(20):  # set 20 random pixels to white
        temp_x = np.random.randint(0, img.shape[0])
        temp_y = np.random.randint(0, img.shape[1])
        img[temp_x][temp_y] = 255
    return img


# mild erosion
def add_erode(img):
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    img = cv2.erode(img, kernel)
    return img


# mild dilation
def add_dilate(img):
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    img = cv2.dilate(img, kernel)
    return img


# Random perturbation (earlier list-based version, kept commented out):
# def do(self, img_list=[]):
#     aug_list = copy.deepcopy(img_list)
#     for i in range(len(img_list)):
#         im = img_list[i]
#         if self.noise and random.random() < 0.5:
#             im = self.add_noise(im)
#         if self.dilate and random.random() < 0.25:
#             im = self.add_dilate(im)
#         if self.erode and random.random() < 0.25:
#             im = self.add_erode(im)
#         aug_list.append(im)
#     return aug_list


# randomly apply the augmentations above
def img_process(img):
    if random.random() < 0.5:
        img = add_noise(img)
    if random.random() < 0.25:
        img = add_dilate(img)
    if random.random() < 0.25:
        img = add_erode(img)
    return img
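

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): exercise the augmentation and
    # normalisation helpers on a synthetic 48x200 BGR image, no config needed.
    dummy = np.random.randint(0, 256, (48, 200, 3), dtype=np.uint8)
    augmented = img_process(dummy.copy())
    padded, draw_img = resize_norm_img(augmented)
    print("augmented:", augmented.shape)  # (48, 200, 3)
    print("padded:   ", padded.shape)     # (3, 32, 320)
    print("draw_img: ", draw_img.shape)   # (32, 320, 3)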