"""360CC text-recognition dataset and simple OpenCV image augmentations."""
from __future__ import print_function, absolute_import

import math
import os
import random

import cv2
import numpy as np
import torch.utils.data as data


def _scaled_width(src_h, src_w, target_h, max_w):
    """Return the width of an image rescaled to ``target_h``, capped at ``max_w``.

    The aspect ratio ``src_w / src_h`` is preserved and the scaled width is
    rounded up; if it exceeds ``max_w`` the cap ``max_w`` is returned as-is.
    """
    ratio = src_w / float(src_h)
    scaled = math.ceil(target_h * ratio)
    if scaled > max_w:
        return max_w
    return int(scaled)


class _360CC(data.Dataset):
    """Dataset of text-line images with labels given as character indices.

    Each entry of ``self.labels`` is a single-item dict ``{image_name: text}``.
    """

    def __init__(self, config, is_train=True):
        """Read the character table and the label file named in ``config``.

        :param config: configuration object; this constructor reads
            ``DATASET.ROOT``, ``DATASET.DATASET``, ``DATASET.MEAN``,
            ``DATASET.STD``, ``DATASET.CHAR_FILE``, ``DATASET.JSON_FILE``
            (keys ``'train'`` / ``'val'``) and ``MODEL.IMAGE_SIZE.H`` / ``.W``.
        :param is_train: selects the ``'train'`` label file when true,
            otherwise the ``'val'`` one.
        """
        self.root = config.DATASET.ROOT
        self.is_train = is_train
        self.inp_h = config.MODEL.IMAGE_SIZE.H
        self.inp_w = config.MODEL.IMAGE_SIZE.W

        self.dataset_name = config.DATASET.DATASET

        self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
        self.std = np.array(config.DATASET.STD, dtype=np.float32)

        # The character file is read as bytes and decoded as GBK (undecodable
        # bytes are dropped); the 0-based line number is the character index.
        char_file = config.DATASET.CHAR_FILE
        with open(char_file, 'rb') as file:
            char_dict = {
                num: char.strip().decode('gbk', 'ignore')
                for num, char in enumerate(file)
            }

        split = 'train' if is_train else 'val'
        txt_file = config.DATASET.JSON_FILE[split]

        # Convert "name idx idx ..." lines to {name: string} entries.
        self.labels = []
        with open(txt_file, 'r', encoding='utf-8') as file:
            for line in file:
                # split() (no argument) drops the trailing newline and
                # tolerates repeated or trailing whitespace.
                fields = line.split()
                if not fields:
                    # A blank line carries no sample; skip it instead of
                    # registering a bogus image name.
                    continue
                imgname = fields[0]
                string = ''.join(char_dict[int(idx)] for idx in fields[1:])
                self.labels.append({imgname: string})

        print("load {} images!".format(len(self)))

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load, augment, resize, normalise and right-pad one image.

        :param idx: position of the sample in ``self.labels``.
        :return: ``(image, idx)`` where ``image`` is a float32 array of shape
            ``(3, self.inp_h, self.inp_w)``; columns beyond the resized image
            width are zero.
        :raises IOError: if OpenCV cannot read the image file.
        """
        img_name = next(iter(self.labels[idx]))
        img_path = os.path.join(self.root, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError("failed to read image: {}".format(img_path))

        # Data augmentation (random noise / dilation / erosion).
        img = img_process(img)

        img_h, img_w = img.shape[:2]
        # Scale to the configured height keeping the aspect ratio, but never
        # wider than the configured width.
        resized_w = _scaled_width(img_h, img_w, self.inp_h, self.inp_w)
        img = cv2.resize(img, (resized_w, self.inp_h))

        # Normalise with the configured mean / std.
        img = img.astype(np.float32)
        img = (img / 255. - self.mean) / self.std
        img = img.transpose([2, 0, 1])

        # Zero-pad on the right where the image is narrower than inp_w.
        padding_im = np.zeros((3, self.inp_h, self.inp_w), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = img

        return padding_im, idx


def resize_norm_img(img, img_h=32, img_w=320):
    """Resize an image to a fixed height, normalise it and right-pad with zeros.

    :param img: input image as an ``(H, W, 3)`` array.
    :param img_h: target height (default 32).
    :param img_w: maximum / padded width (default 320, i.e. ten times the
        default height; the original author chose this for roughly ten
        square Chinese characters per line).
    :return: ``(padding_im, draw_img)`` where ``padding_im`` is a float32
        array of shape ``(3, img_h, img_w)`` with values mapped from
        ``[0, 255]`` to ``[-1, 1]``, and ``draw_img`` is the same data
        transposed to ``(img_h, img_w, 3)`` for visualisation.
    """
    channels = 3

    src_h, src_w = img.shape[:2]
    # Keep the aspect ratio, capping the width at img_w.
    resized_w = _scaled_width(src_h, src_w, img_h, img_w)

    resized_image = cv2.resize(img, (resized_w, img_h))
    resized_image = resized_image.astype('float32')

    # Channels first, then map [0, 255] -> [-1, 1].
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5

    # Zero-pad on the right where the image is narrower than img_w.
    padding_im = np.zeros((channels, img_h, img_w), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image

    # Channels-last view of the padded image, for visualisation.
    draw_img = padding_im.transpose((1, 2, 0))
    return padding_im, draw_img


def add_noise(img, num_points=20):
    """Set ``num_points`` randomly chosen pixels of ``img`` to 255.

    The image is modified in place and also returned. Positions are drawn
    with ``np.random.randint`` (row first, then column, per point), so
    results are reproducible under a fixed NumPy seed.
    """
    for _ in range(num_points):
        row = np.random.randint(0, img.shape[0])
        col = np.random.randint(0, img.shape[1])
        img[row, col] = 255
    return img


def add_erode(img):
    """Return ``img`` eroded once with a 3x3 rectangular kernel."""
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    return cv2.erode(img, kernel)


def add_dilate(img):
    """Return ``img`` dilated once with a 3x3 rectangular kernel."""
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    return cv2.dilate(img, kernel)


def img_process(img):
    """Apply random augmentations to ``img`` and return the result.

    Independently: point noise with probability 0.5, dilation with
    probability 0.25, erosion with probability 0.25 (checked in that order).
    The noise step writes into ``img`` in place.
    """
    if random.random() < 0.5:
        img = add_noise(img)
    if random.random() < 0.25:
        img = add_dilate(img)
    if random.random() < 0.25:
        img = add_erode(img)
    return img