123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461 |
- # -*-coding:utf-8-*-
- """
- @Author: Jeffery Sheng (Zhenfei Sheng)
- @Time: 2020/5/21 19:44
- @File: RecDataSet.py
- """
- import six
- import cv2
- import torch
- import numpy as np
- from PIL import Image
- from torch.utils.data import Dataset, DataLoader
- from torchocr.utils.CreateRecAug import cv2pil, pil2cv, RandomBrightness, RandomContrast, \
- RandomLine, RandomSharpness, Compress, Rotate, \
- Blur, MotionBlur, Salt, AdjustResolution
- import re
class RecTextLineDataset(Dataset):
    """Text-line recognition dataset read from a plain-text annotation file.

    Two annotation formats are understood:

    * index format (default): one sample per line,
      ``<image_name> <idx> <idx> ...`` separated by whitespace. Each ``idx`` is
      an integer position in the alphabet file. ``image_name`` is joined onto
      ``config.image_root``, or onto ``DEFAULT_IMAGE_ROOT`` when that field is
      unset. Lines with a non-integer or out-of-range index are skipped.
    * Paddle recognition log format, selected when the path ``config.file``
      contains ``log4.log``: ``<image_path> <text> [[x, y], ...] rec_res: ...``.
      Each sample is cropped to the axis-aligned rectangle around the polygon.

    ``self.labels`` holds ``(img_path, text)`` tuples for the index format and
    ``(img_path, text, box_literal)`` tuples for the Paddle log format.
    """

    # Image directory that earlier versions of this loader hard-coded. It is
    # kept as the fallback so existing configs keep working; set
    # ``config.image_root`` to override it.
    DEFAULT_IMAGE_ROOT = ('/data2/znj/CRNN_Chinese_Characters_Rec/data/data/python_znj/'
                          'Lets_OCR/recognizer/crnn/data/images/')

    def __init__(self, config):
        """
        :param config: dataset config, usually config['dataset']['train']['dataset'] or
            config['dataset']['eval']['dataset']. Fields read here:
            file: annotation file path
            alphabet: path of the alphabet file. Its lines are concatenated and
                the token "blank" is replaced by a space character.
            augmentation: whether __getitem__ applies data augmentation
            image_root (optional): directory prepended to image names (index format only)
            RecDataProcess additionally reads input_h, mean and std.
        :return None
        """
        self.augmentation = config.augmentation
        self.process = RecDataProcess(config)
        with open(config.alphabet, 'r', encoding='utf-8') as file:
            alphabet = ''.join(s.strip('\n') for s in file)
        # The alphabet file spells the space character as the word "blank".
        alphabet = alphabet.replace("blank", " ")
        self.str2idx = {c: i for i, c in enumerate(alphabet)}
        # Treat an unset or empty image_root as "use the default".
        image_root = getattr(config, 'image_root', None) or self.DEFAULT_IMAGE_ROOT
        is_paddle_log = "log4.log" in config.file
        self.labels = []
        # Parse the file exactly once, in the format implied by its name.
        with open(config.file, 'r', encoding='utf-8') as f_reader:
            for m_line in f_reader:
                m_line = m_line.strip()
                if not m_line:
                    continue
                if is_paddle_log:
                    sample = self._parse_paddle_log_line(m_line)
                else:
                    sample = self._parse_index_line(m_line, alphabet, image_root)
                if sample is not None:
                    self.labels.append(sample)

    @staticmethod
    def _parse_index_line(m_line, alphabet, image_root):
        """Parse ``<image_name> <idx> ...`` into ``(img_path, text)``.

        Returns None if the line has no indices, or if any index is not an
        integer in ``[0, len(alphabet))``.
        """
        import os
        params = m_line.split()
        if len(params) < 2:
            return None
        chars = []
        for token in params[1:]:
            try:
                idx = int(token)
            except ValueError:
                return None
            # Reject negatives explicitly; alphabet[-1] would silently succeed.
            if not 0 <= idx < len(alphabet):
                return None
            chars.append(alphabet[idx])
        return os.path.join(image_root, params[0]), ''.join(chars)

    @staticmethod
    def _parse_paddle_log_line(m_line):
        """Parse ``<image_path> <text> [[...]] rec_res: ...`` into ``(img_path, text, box_literal)``.

        Returns None if the line has no text/box separator.
        """
        image_path, sep, rest = m_line.partition(' ')
        if not sep:
            return None
        text, sep, box = rest.partition(' [[')
        if not sep:
            return None
        # Drop the trailing " rec_res: ..." part (if present) and restore the
        # "[[" consumed by the partition above.
        box = '[[' + box.partition(' rec_res:')[0]
        return image_path, text, box

    @staticmethod
    def _read_rgb(img_path):
        """Read ``img_path`` with OpenCV and convert BGR to RGB.

        :raises FileNotFoundError: if OpenCV cannot read the image
        """
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError('cannot read image: %s' % img_path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _crop_to_box(img, box):
        """Crop ``img`` to the axis-aligned bounding rectangle of the polygon literal ``box``.

        ``ast.literal_eval`` is used instead of ``eval`` because the box text
        comes from a file. Coordinates are clamped at 0 so that a box that
        pokes slightly outside the image does not wrap the slice around.
        """
        import ast
        points = ast.literal_eval(box)
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        x1, x2 = max(int(min(xs)), 0), max(int(max(xs)), 0)
        y1, y2 = max(int(min(ys)), 0), max(int(max(ys)), 0)
        return img[y1:y2, x1:x2]

    def _find_max_length(self):
        """Length of the longest label text, or 0 for an empty dataset."""
        return max((len(label[1]) for label in self.labels), default=0)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``{'img': RGB ndarray, 'label': text}`` for sample ``index``.

        Paddle-log samples are cropped to their box and never augmented.
        Index-format samples are augmented when ``self.augmentation`` is set.
        """
        label = self.labels[index]
        img_path, trans = label[0], label[1]
        img = self._read_rgb(img_path)
        if len(label) == 3:
            img = self._crop_to_box(img, label[2])
        elif self.augmentation:
            img = pil2cv(self.process.aug_img(cv2pil(img)))
        return {'img': img, 'label': trans}
class RecTextLineDataset2(Dataset):
    """Text-line dataset for ``<image_path> <idx> <idx> ...`` annotation files.

    Each ``idx`` is an integer position in the alphabet file. Image paths are
    used exactly as written unless ``config.image_root`` is set, in which case
    they are joined onto it. Lines with a non-integer or out-of-range index are
    skipped.
    """

    def __init__(self, config):
        """
        :param config: dataset config, usually config['dataset']['train']['dataset'] or
            config['dataset']['eval']['dataset']. Fields read here:
            file: annotation file path
            alphabet: path of the alphabet file. Its lines are concatenated and
                the token "blank" is replaced by a space character.
            augmentation: whether __getitem__ applies data augmentation
            image_root (optional): directory prepended to image paths
            RecDataProcess additionally reads input_h, mean and std.
        :return None
        """
        self.augmentation = config.augmentation
        self.process = RecDataProcess(config)
        # Previously this opened "" and referenced an undefined `alphabet`, so
        # construction always failed. Load both from the config instead.
        with open(config.alphabet, 'r', encoding='utf-8') as file:
            alphabet = ''.join(s.strip('\n') for s in file)
        alphabet = alphabet.replace("blank", " ")
        self.str2idx = {c: i for i, c in enumerate(alphabet)}
        image_root = getattr(config, 'image_root', None) or ''
        self.labels = []
        with open(config.file, 'r', encoding='utf-8') as f_reader:
            for m_line in f_reader:
                sample = self._parse_index_line(m_line.strip(), alphabet, image_root)
                if sample is not None:
                    self.labels.append(sample)

    @staticmethod
    def _parse_index_line(m_line, alphabet, image_root):
        """Parse ``<image_path> <idx> ...`` into ``(img_path, text)``, or return None if invalid."""
        import os
        params = m_line.split()
        if len(params) < 2:
            return None
        chars = []
        for token in params[1:]:
            try:
                idx = int(token)
            except ValueError:
                return None
            if not 0 <= idx < len(alphabet):
                return None
            chars.append(alphabet[idx])
        return os.path.join(image_root, params[0]), ''.join(chars)

    def _find_max_length(self):
        """Length of the longest label text, or 0 for an empty dataset."""
        return max((len(label[1]) for label in self.labels), default=0)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``{'img': RGB ndarray, 'label': text}`` for sample ``index``.

        :raises FileNotFoundError: if OpenCV cannot read the image
        """
        img_path, trans = self.labels[index]
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError('cannot read image: %s' % img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.augmentation:
            img = pil2cv(self.process.aug_img(cv2pil(img)))
        return {'img': img, 'label': trans}
class RecLmdbDataset(Dataset):
    """Text-line recognition dataset stored in an LMDB database.

    Keys read (1-based sample numbers): ``num-samples``, ``label-%09d``
    (UTF-8 text) and ``image-%09d`` (encoded image bytes opened with PIL).
    Samples whose label contains a character outside ``config.alphabet``, or
    whose label key is missing, are dropped.
    """

    def __init__(self, config):
        """
        :param config: dataset config, usually config['dataset']['train']['dataset'] or
            config['dataset']['eval']['dataset']. Fields read here:
            file: path of the lmdb database
            alphabet: iterable of allowed characters
            augmentation: whether __getitem__ applies data augmentation
            RecDataProcess additionally reads input_h, mean and std.
        :raises IOError: if the database handle is falsy after opening
        """
        import lmdb
        self.env = lmdb.open(config.file, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
        if not self.env:
            # Fail loudly. The previous sys.exit(0) reported success to the shell.
            raise IOError('cannot create lmdb from %s' % (config.file))
        self.augmentation = config.augmentation
        self.process = RecDataProcess(config)
        self.filtered_index_list = []
        self.labels = []
        # NOTE(review): config.alphabet is enumerated character by character
        # here, while RecTextLineDataset opens it as a file path. Confirm which
        # form this config provides.
        self.str2idx = {c: i for i, c in enumerate(config.alphabet)}
        with self.env.begin(write=False) as txn:
            self.nSamples = int(txn.get('num-samples'.encode()))
            for index in range(1, self.nSamples + 1):  # lmdb sample keys start at 1
                raw_label = txn.get('label-%09d'.encode() % index)
                if raw_label is None:
                    continue
                label = raw_label.decode('utf-8')
                # todo: optionally drop labels longer than config.max_len
                # Drop samples containing characters outside the alphabet rather
                # than mapping them to an [UNK] token.
                if any(c not in self.str2idx for c in label):
                    continue
                self.labels.append(label)
                self.filtered_index_list.append(index)

    def _find_max_length(self):
        """Length of the longest kept label, or 0 for an empty dataset."""
        return max((len(label) for label in self.labels), default=0)

    def __getitem__(self, index):
        """Return ``{'img': RGB ndarray, 'label': text}`` for the ``index``-th kept sample.

        :raises KeyError: if the image key for this sample is missing
        """
        key_index = self.filtered_index_list[index]
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get('image-%09d'.encode() % key_index)
            if imgbuf is None:
                raise KeyError('image-%09d not found in lmdb' % key_index)
            img = Image.open(six.BytesIO(imgbuf)).convert('RGB')  # force 3-channel color
        img = np.array(img)
        if self.augmentation:
            img = pil2cv(self.process.aug_img(cv2pil(img)))
        # labels and filtered_index_list are parallel, so reuse the decoded label.
        return {'img': img, 'label': self.labels[index]}

    def __len__(self):
        return len(self.filtered_index_list)
class RecDataLoader:
    """Length-bucketing batch iterator for recognition datasets.

    Samples are pulled one at a time from a torch DataLoader (batch_size=1)
    and routed into two queues by label length: ``queue_1`` holds labels no
    longer than half the dataset's longest label, and ``queue_2`` holds the
    rest. A batch is emitted as soon as either queue reaches ``batch_size``,
    so labels of similar length end up together. When the source is
    exhausted, the partial queues are flushed (``queue_1`` first). The
    loader then resets, so the same object can be iterated again for the
    next epoch.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers, **kwargs):
        """
        :param dataset: torch.utils.data.Dataset exposing ``process`` and ``_find_max_length()``
        :param batch_size: number of images per batch
        :param shuffle: whether to shuffle the dataset each epoch
        :param num_workers: number of DataLoader worker processes
        :param kwargs: accepted for config compatibility and ignored
        """
        self.dataset = dataset
        self.process = dataset.process
        self.len_thresh = self.dataset._find_max_length() // 2
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.iteration = 0
        self.dataiter = None
        # Set once the underlying iterator has raised StopIteration in the
        # current epoch. After that, only the leftover queues are flushed.
        self._exhausted = False
        self.queue_1 = list()
        self.queue_2 = list()

    def __len__(self):
        """ceil(len(dataset) / batch_size).

        Short and long labels are batched separately, so an epoch can yield
        one more batch than this (two partial batches at the end).
        """
        return -(-len(self.dataset) // self.batch_size)

    def __iter__(self):
        return self

    def pack(self, batch_data):
        """Merge single-sample DataLoader outputs into ``{'img': B,C,H,W float tensor, 'label': [str]}``."""
        batch = {'img': [], 'label': []}
        # Each sample's img tensor has a leading batch dim of 1 (shape 1,H,W,C).
        all_same_height_images = [self.process.resize_with_specific_height(sample['img'][0].numpy())
                                  for sample in batch_data]
        max_img_w = max(m_img.shape[1] for m_img in all_same_height_images)
        # Round the common width up to a multiple of 8.
        max_img_w = int(np.ceil(max_img_w / 8) * 8)
        for sample, img in zip(batch_data, all_same_height_images):
            img = self.process.normalize_img(self.process.width_pad_img(img, max_img_w))
            batch['img'].append(torch.tensor(img.transpose([2, 0, 1]), dtype=torch.float))
            batch['label'].append(sample['label'][0])
        batch['img'] = torch.stack(batch['img'])
        return batch

    def build(self):
        """Create a fresh single-sample DataLoader iterator for a new epoch."""
        self.dataiter = DataLoader(self.dataset, batch_size=1, shuffle=self.shuffle,
                                   num_workers=self.num_workers).__iter__()

    def __next__(self):
        if self.dataiter is None:
            self.build()
        while not self._exhausted:
            try:
                temp = next(self.dataiter)
            except StopIteration:
                self._exhausted = True
                break
            self.iteration += 1
            if len(temp['label'][0]) <= self.len_thresh:
                self.queue_1.append(temp)
            else:
                self.queue_2.append(temp)
            # Only one queue grew, so at most one of these can be full.
            if len(self.queue_1) == self.batch_size:
                batch_data, self.queue_1 = self.queue_1, list()
                return self.pack(batch_data)
            if len(self.queue_2) == self.batch_size:
                batch_data, self.queue_2 = self.queue_2, list()
                return self.pack(batch_data)
        # Source exhausted: flush both partial queues. Previously a non-empty
        # queue_2 was dropped when queue_1 was empty and then leaked into the
        # next epoch.
        if self.queue_1:
            batch_data, self.queue_1 = self.queue_1, list()
            return self.pack(batch_data)
        if self.queue_2:
            batch_data, self.queue_2 = self.queue_2, list()
            return self.pack(batch_data)
        # Epoch finished: reset so the next `for` loop starts a fresh epoch.
        self.iteration = 0
        self.dataiter = None
        self._exhausted = False
        raise StopIteration
class RecDataProcess:
    """Augmentation and batching helpers for text-recognition images.

    The helper methods read ``input_h``, ``mean`` and ``std`` from the config.
    """

    def __init__(self, config):
        """
        :param config: config object; input_h, mean and std are read by the helper methods
        """
        self.config = config
        p = 0.3
        self.random_contrast = RandomContrast(probability=p)
        self.random_brightness = RandomBrightness(probability=p)
        self.random_sharpness = RandomSharpness(probability=p)
        self.compress = Compress(probability=p)
        self.rotate = Rotate(probability=0.5)
        self.blur = Blur(probability=p)
        self.motion_blur = MotionBlur(probability=p)
        self.salt = Salt(probability=p)
        self.adjust_resolution = AdjustResolution(probability=p)
        self.random_line = RandomLine(probability=p)
        # random_line is the only augmenter whose setparam() is not called.
        initialised = (self.random_contrast, self.random_brightness, self.random_sharpness,
                       self.compress, self.rotate, self.blur, self.motion_blur,
                       self.salt, self.adjust_resolution)
        for augmenter in initialised:
            augmenter.setparam()

    def aug_img(self, img):
        """Run the random augmentation chain on a PIL image and return the result."""
        for stage in (self.random_contrast, self.random_brightness,
                      self.random_sharpness, self.random_line):
            img = stage.process(img)
        # PIL's size is (width, height): the heavier degradations run only on
        # images at least 32 px tall.
        # NOTE(review): indentation was lost in the source this was recovered
        # from. Confirm that exactly these four stages belong inside the
        # height check.
        if img.size[1] >= 32:
            for stage in (self.compress, self.adjust_resolution,
                          self.motion_blur, self.blur):
                img = stage.process(img)
        for stage in (self.rotate, self.salt):
            img = stage.process(img)
        return img

    def resize_with_specific_height(self, _img):
        """Scale ``_img`` uniformly (aspect ratio kept) so its height equals ``config.input_h``."""
        scale = self.config.input_h / _img.shape[0]
        return cv2.resize(_img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

    def normalize_img(self, _img):
        """Scale pixel values to [0, 1], then subtract ``config.mean`` and divide by ``config.std``."""
        scaled = _img.astype(np.float32) / 255
        return (scaled - self.config.mean) / self.config.std

    def width_pad_img(self, _img, _target_width, _pad_value=0):
        """Right-pad an H,W,C image to ``_target_width`` columns, keeping its height.

        :param _img: image to pad (3-D array)
        :param _target_width: output width; must be >= the input width
        :param _pad_value: value written into the padded columns
        :return: the padded image
        """
        rows, cols, channels = _img.shape
        canvas = _pad_value * np.ones([rows, _target_width, channels], dtype=_img.dtype)
        canvas[:, :cols, :] = _img
        return canvas
|