RecDataSet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. # -*-coding:utf-8-*-
  2. """
  3. @Author: Jeffery Sheng (Zhenfei Sheng)
  4. @Time: 2020/5/21 19:44
  5. @File: RecDataSet.py
  6. """
  7. import six
  8. import cv2
  9. import torch
  10. import numpy as np
  11. from PIL import Image
  12. from torch.utils.data import Dataset, DataLoader
  13. from torchocr.utils.CreateRecAug import cv2pil, pil2cv, RandomBrightness, RandomContrast, \
  14. RandomLine, RandomSharpness, Compress, Rotate, \
  15. Blur, MotionBlur, Salt, AdjustResolution
  16. import re
  17. class RecTextLineDataset(Dataset):
  18. def __init__(self, config):
  19. """
  20. 文本行 DataSet, 用于处理标注格式为 `img_path\tlabel` 的标注格式
  21. :param config: 相关配置,一般为 config['dataset']['train']['dataset] or config['dataset']['eval']['dataset]
  22. 其主要应包含如下字段: file: 标注文件路径
  23. input_h: 图片的目标高
  24. mean: 归一化均值
  25. std: 归一化方差
  26. augmentation: 使用使用数据增强
  27. :return None
  28. """
  29. self.augmentation = config.augmentation
  30. self.process = RecDataProcess(config)
  31. with open(config.alphabet, 'r', encoding='utf-8') as file:
  32. alphabet = ''.join([s.strip('\n') for s in file.readlines()])
  33. # alphabet += ' '
  34. alphabet = alphabet.replace("blank"," ") #add
  35. self.str2idx = {c: i for i, c in enumerate(alphabet)}
  36. self.labels = []
  37. # if "test.txt" in config.file:
  38. with open(config.file, 'r', encoding='utf-8') as f_reader:
  39. for m_line in f_reader.readlines():
  40. m_line=m_line.strip()
  41. params = m_line.split(' ')
  42. # print(params)
  43. if len(params) >= 2:
  44. m_image_name = params[0]
  45. m_image_name = '/data2/znj/CRNN_Chinese_Characters_Rec/data/data/python_znj/Lets_OCR/recognizer/crnn/data/images/'+m_image_name
  46. m_gt_text = params[1:]
  47. # print(m_gt_text)
  48. m_gt_text = "".join([alphabet[int(idx)] if int(idx)<len(alphabet) else '#none#' for idx in m_gt_text])
  49. # if True in [c not in self.str2idx for c in m_gt_text]:
  50. # continue
  51. if "#none#" in m_gt_text:
  52. continue
  53. self.labels.append((m_image_name, m_gt_text))
  54. # with open(config.file, 'r', encoding='utf-8') as f_reader:
  55. # for m_line in f_reader.readlines():
  56. # is_skip = False
  57. # m_line=m_line.strip()
  58. # params = m_line.split(' ')
  59. # if len(params) >= 2:
  60. # m_image_name = params[0]
  61. # m_image_name = '/data2/znj/text_renderer/output2/default/' + m_image_name + '.jpg'
  62. # m_gt_text = params[1]
  63. # for w in m_gt_text:
  64. # if w not in alphabet:
  65. # is_skip = True
  66. # break
  67. # if is_skip:
  68. # continue
  69. # self.labels.append((m_image_name, m_gt_text))
  70. # add
  71. # if "train.txt" in config.file:
  72. # with open("/data2/znj/PytorchOCR/data/train.txt", 'r', encoding='utf-8') as f_reader:
  73. # for m_line in f_reader.readlines():
  74. # m_line=m_line.strip()
  75. # params = m_line.split('\t')
  76. # # print(params)
  77. # if len(params) >= 3:
  78. # m_image_name = params[2]
  79. # m_image_name = '/data2/znj/PytorchOCR/data/image/'+m_image_name
  80. # m_gt_text = params[3]
  81. # # print(m_gt_text)
  82. # # m_gt_text = "".join([alphabet[int(idx)] if int(idx)<len(alphabet) else '#none#' for idx in m_gt_text])
  83. # # if True in [c not in self.str2idx for c in m_gt_text]:
  84. # # continue
  85. # _m_gt_text = "".join([str(self.str2idx.get(i,'#none#')) for i in m_gt_text])
  86. # if "#none#" in _m_gt_text:
  87. # continue
  88. # self.labels.append((m_image_name, m_gt_text))
  89. # add
  90. # if "train.txt" in config.file:
  91. # with open("/data2/znj/text_renderer/output3/default/tmp_labels.txt", 'r', encoding='utf-8') as f_reader:
  92. # for m_line in f_reader.readlines():
  93. # is_skip = False
  94. # m_line=m_line.strip()
  95. # params = m_line.split(' ')
  96. # # print(params)
  97. # if len(params) >= 2:
  98. # m_image_name = params[0]
  99. # m_image_name = '/data2/znj/text_renderer/output3/default/' + m_image_name + '.jpg'
  100. # m_gt_text = params[1]
  101. # for w in m_gt_text:
  102. # if w not in alphabet:
  103. # is_skip = True
  104. # break
  105. # if is_skip:
  106. # continue
  107. # self.labels.append((m_image_name, m_gt_text))
  108. # test
  109. with open(config.file, 'r', encoding='utf-8') as f_reader:
  110. for m_line in f_reader.readlines():
  111. m_line=m_line.strip()
  112. params = m_line.split(' ')
  113. # print(params)
  114. if len(params) >= 2:
  115. m_image_name = params[0]
  116. m_gt_text = params[1:]
  117. # print(m_gt_text)
  118. m_gt_text = "".join([alphabet[int(idx)] if int(idx)<len(alphabet) else '#none#' for idx in m_gt_text])
  119. # if True in [c not in self.str2idx for c in m_gt_text]:
  120. # continue
  121. if "#none#" in m_gt_text:
  122. continue
  123. # self.labels.append((m_image_name, m_gt_text))
  124. self.labels.append((m_image_name, 'a123'))
  125. print(self.labels)
  126. # paddle 识别数据
  127. if "log4.log" in config.file:
  128. with open(config.file, 'r', encoding='utf-8') as f_reader:
  129. for m_line in f_reader.readlines():
  130. m_line=m_line.strip()
  131. iamge_path, line_split2 = re.split(" ", m_line, maxsplit=1)
  132. text, box = re.split(" \[\[", line_split2, maxsplit=1)
  133. box = '[[' + box
  134. box,_ = re.split(" rec_res:", box, maxsplit=1)
  135. self.labels.append((iamge_path, text,box))
  136. print(self.labels)
  137. def _find_max_length(self):
  138. return max({len(_[1]) for _ in self.labels})
  139. def __len__(self):
  140. return len(self.labels)
  141. def __getitem__(self, index):
  142. # get img_path and trans
  143. # img_path, trans = self.labels[index]
  144. label = self.labels[index]
  145. if len(label)==2:
  146. img_path, trans = label
  147. # read img
  148. img = cv2.imread(img_path)
  149. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  150. elif len(label)==3:
  151. img_path, trans,box = label
  152. # read img
  153. img = cv2.imread(img_path)
  154. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  155. bbox = eval(box)
  156. x1 = int(min([i[0] for i in bbox]))
  157. x2 = int(max([i[0] for i in bbox]))
  158. y1 = int(min([i[1] for i in bbox]))
  159. y2 = int(max([i[1] for i in bbox]))
  160. img = img[y1:y2, x1:x2]
  161. # do aug
  162. if len(label)==2:
  163. if self.augmentation:
  164. img = pil2cv(self.process.aug_img(cv2pil(img)))
  165. return {'img': img, 'label': trans}
  166. class RecTextLineDataset2(Dataset):
  167. def __init__(self, config):
  168. """
  169. 文本行 DataSet, 用于处理标注格式为 `img_path\tlabel` 的标注格式
  170. :param config: 相关配置,一般为 config['dataset']['train']['dataset] or config['dataset']['eval']['dataset]
  171. 其主要应包含如下字段: file: 标注文件路径
  172. input_h: 图片的目标高
  173. mean: 归一化均值
  174. std: 归一化方差
  175. augmentation: 使用使用数据增强
  176. :return None
  177. """
  178. self.augmentation = config.augmentation
  179. self.process = RecDataProcess(config)
  180. self.labels = []
  181. with open("", 'r', encoding='utf-8') as f_reader:
  182. for m_line in f_reader.readlines():
  183. m_line=m_line.strip()
  184. params = m_line.split(' ')
  185. if len(params) >= 2:
  186. m_image_name = params[0]
  187. m_gt_text = params[1:]
  188. m_gt_text = "".join([alphabet[int(idx)] if int(idx)<len(alphabet) else '#none#' for idx in m_gt_text])
  189. if "#none#" in m_gt_text:
  190. continue
  191. self.labels.append((m_image_name, 'a123'))
  192. print(self.labels)
  193. def _find_max_length(self):
  194. return max({len(_[1]) for _ in self.labels})
  195. def __len__(self):
  196. return len(self.labels)
  197. def __getitem__(self, index):
  198. # get img_path and trans
  199. img_path, trans = self.labels[index]
  200. # read img
  201. img = cv2.imread(img_path)
  202. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  203. # do aug
  204. if self.augmentation:
  205. img = pil2cv(self.process.aug_img(cv2pil(img)))
  206. return {'img': img, 'label': trans}
  207. class RecLmdbDataset(Dataset):
  208. def __init__(self, config):
  209. """
  210. Lmdb DataSet, 用于处理转换为 lmdb 文件后的数据集
  211. :param config: 相关配置,一般为 config['dataset']['train']['dataset] or config['dataset']['eval']['dataset]
  212. 其主要应包含如下字段: file: 标注文件路径
  213. input_h: 图片的目标高
  214. mean: 归一化均值
  215. std: 归一化方差
  216. augmentation: 使用使用数据增强
  217. :return None
  218. """
  219. import lmdb, sys
  220. self.env = lmdb.open(config.file, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
  221. if not self.env:
  222. print('cannot create lmdb from %s' % (config.file))
  223. sys.exit(0)
  224. self.augmentation = config.augmentation
  225. self.process = RecDataProcess(config)
  226. self.filtered_index_list = []
  227. self.labels = []
  228. self.str2idx = {c: i for i, c in enumerate(config.alphabet)}
  229. with self.env.begin(write=False) as txn:
  230. nSamples = int(txn.get('num-samples'.encode()))
  231. self.nSamples = nSamples
  232. for index in range(self.nSamples):
  233. index += 1 # lmdb starts with 1
  234. label_key = 'label-%09d'.encode() % index
  235. label = txn.get(label_key).decode('utf-8')
  236. # todo 添加 过滤最长
  237. # if len(label) > config.max_len:
  238. # # print(f'The length of the label is longer than max_length: length
  239. # # {len(label)}, {label} in dataset {self.root}')
  240. # continue
  241. if True in [c not in self.str2idx for c in label]:
  242. continue
  243. # By default, images containing characters which are not in opt.character are filtered.
  244. # You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
  245. self.labels.append(label)
  246. self.filtered_index_list.append(index)
  247. def _find_max_length(self):
  248. return max({len(_) for _ in self.labels})
  249. def __getitem__(self, index):
  250. index = self.filtered_index_list[index]
  251. with self.env.begin(write=False) as txn:
  252. label_key = 'label-%09d'.encode() % index
  253. label = txn.get(label_key).decode('utf-8')
  254. img_key = 'image-%09d'.encode() % index
  255. imgbuf = txn.get(img_key)
  256. buf = six.BytesIO()
  257. buf.write(imgbuf)
  258. buf.seek(0)
  259. img = Image.open(buf).convert('RGB') # for color image
  260. # We only train and evaluate on alphanumerics (or pre-defined character set in rec_train.py)
  261. img = np.array(img)
  262. if self.augmentation:
  263. img = pil2cv(self.process.aug_img(cv2pil(img)))
  264. return {'img': img, 'label': label}
  265. def __len__(self):
  266. return len(self.filtered_index_list)
  267. class RecDataLoader:
  268. def __init__(self, dataset, batch_size, shuffle, num_workers, **kwargs):
  269. """
  270. 自定义 DataLoader, 主要实现数据集的按长度划分,将长度相近的放在一个 batch
  271. :param dataset: 继承自 torch.utils.data.DataSet的类对象
  272. :param batch_size: 一个 batch 的图片数量
  273. :param shuffle: 是否打乱数据集
  274. :param num_workers: 后台进程数
  275. :param kwargs: **
  276. """
  277. self.dataset = dataset
  278. self.process = dataset.process
  279. self.len_thresh = self.dataset._find_max_length() // 2
  280. self.batch_size = batch_size
  281. self.shuffle = shuffle
  282. self.num_workers = num_workers
  283. self.iteration = 0
  284. self.dataiter = None
  285. self.queue_1 = list()
  286. self.queue_2 = list()
  287. def __len__(self):
  288. return len(self.dataset) // self.batch_size if len(self.dataset) % self.batch_size == 0 \
  289. else len(self.dataset) // self.batch_size + 1
  290. def __iter__(self):
  291. return self
  292. def pack(self, batch_data):
  293. batch = {'img': [], 'label': []}
  294. # img tensor current shape: B,H,W,C
  295. all_same_height_images = [self.process.resize_with_specific_height(_['img'][0].numpy()) for _ in batch_data]
  296. max_img_w = max({m_img.shape[1] for m_img in all_same_height_images})
  297. # make sure max_img_w is integral multiple of 8
  298. max_img_w = int(np.ceil(max_img_w / 8) * 8)
  299. for i in range(len(batch_data)):
  300. _label = batch_data[i]['label'][0]
  301. img = self.process.normalize_img(self.process.width_pad_img(all_same_height_images[i], max_img_w))
  302. img = img.transpose([2, 0, 1])
  303. batch['img'].append(torch.tensor(img, dtype=torch.float))
  304. batch['label'].append(_label)
  305. batch['img'] = torch.stack(batch['img'])
  306. return batch
  307. def build(self):
  308. self.dataiter = DataLoader(self.dataset, batch_size=1, shuffle=self.shuffle,
  309. num_workers=self.num_workers).__iter__()
  310. def __next__(self):
  311. if self.dataiter == None:
  312. self.build()
  313. if self.iteration == len(self.dataset) and len(self.queue_2):
  314. batch_data = self.queue_2
  315. self.queue_2 = list()
  316. return self.pack(batch_data)
  317. if not len(self.queue_2) and not len(self.queue_1) and self.iteration == len(self.dataset):
  318. self.iteration = 0
  319. self.dataiter = None
  320. raise StopIteration
  321. # start iteration
  322. try:
  323. while True:
  324. # get data from origin dataloader
  325. temp = self.dataiter.__next__()
  326. self.iteration += 1
  327. # to different queue
  328. if len(temp['label'][0]) <= self.len_thresh:
  329. self.queue_1.append(temp)
  330. else:
  331. self.queue_2.append(temp)
  332. # to store batch data
  333. batch_data = None
  334. # queue_1 full, push to batch_data
  335. if len(self.queue_1) == self.batch_size:
  336. batch_data = self.queue_1
  337. self.queue_1 = list()
  338. # or queue_2 full, push to batch_data
  339. elif len(self.queue_2) == self.batch_size:
  340. batch_data = self.queue_2
  341. self.queue_2 = list()
  342. # start to process batch
  343. if batch_data is not None:
  344. return self.pack(batch_data)
  345. # deal with last batch
  346. except StopIteration:
  347. if self.queue_1 == []:
  348. raise StopIteration
  349. batch_data = self.queue_1
  350. self.queue_1 = list()
  351. return self.pack(batch_data)
  352. class RecDataProcess:
  353. def __init__(self, config):
  354. """
  355. 文本是被数据增广类
  356. :param config: 配置,主要用到的字段有 input_h, mean, std
  357. """
  358. self.config = config
  359. self.random_contrast = RandomContrast(probability=0.3)
  360. self.random_brightness = RandomBrightness(probability=0.3)
  361. self.random_sharpness = RandomSharpness(probability=0.3)
  362. self.compress = Compress(probability=0.3)
  363. self.rotate = Rotate(probability=0.5)
  364. self.blur = Blur(probability=0.3)
  365. self.motion_blur = MotionBlur(probability=0.3)
  366. self.salt = Salt(probability=0.3)
  367. self.adjust_resolution = AdjustResolution(probability=0.3)
  368. self.random_line = RandomLine(probability=0.3)
  369. self.random_contrast.setparam()
  370. self.random_brightness.setparam()
  371. self.random_sharpness.setparam()
  372. self.compress.setparam()
  373. self.rotate.setparam()
  374. self.blur.setparam()
  375. self.motion_blur.setparam()
  376. self.salt.setparam()
  377. self.adjust_resolution.setparam()
  378. def aug_img(self, img):
  379. img = self.random_contrast.process(img)
  380. img = self.random_brightness.process(img)
  381. img = self.random_sharpness.process(img)
  382. img = self.random_line.process(img)
  383. if img.size[1] >= 32:
  384. img = self.compress.process(img)
  385. img = self.adjust_resolution.process(img)
  386. img = self.motion_blur.process(img)
  387. img = self.blur.process(img)
  388. img = self.rotate.process(img)
  389. img = self.salt.process(img)
  390. return img
  391. def resize_with_specific_height(self, _img):
  392. """
  393. 将图像resize到指定高度
  394. :param _img: 待resize的图像
  395. :return: resize完成的图像
  396. """
  397. resize_ratio = self.config.input_h / _img.shape[0]
  398. return cv2.resize(_img, (0, 0), fx=resize_ratio, fy=resize_ratio, interpolation=cv2.INTER_LINEAR)
  399. def normalize_img(self, _img):
  400. """
  401. 根据配置的均值和标准差进行归一化
  402. :param _img: 待归一化的图像
  403. :return: 归一化后的图像
  404. """
  405. return (_img.astype(np.float32) / 255 - self.config.mean) / self.config.std
  406. def width_pad_img(self, _img, _target_width, _pad_value=0):
  407. """
  408. 将图像进行高度不变,宽度的调整的pad
  409. :param _img: 待pad的图像
  410. :param _target_width: 目标宽度
  411. :param _pad_value: pad的值
  412. :return: pad完成后的图像
  413. """
  414. _height, _width, _channels = _img.shape
  415. to_return_img = np.ones([_height, _target_width, _channels], dtype=_img.dtype) * _pad_value
  416. to_return_img[:_height, :_width, :] = _img
  417. return to_return_img