# DetDataSet.py
# -*- coding: utf-8 -*-
# @Time : 2020/6/22 10:53
# @Author : zhoujun
import copy
import json
import os

import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

from torchocr.datasets.det_modules import *
  13. def load_json(file_path: str):
  14. with open(file_path, 'r', encoding='utf8') as f:
  15. content = json.load(f)
  16. return content
  17. class JsonDataset(Dataset):
  18. """
  19. from https://github.com/WenmuZhou/OCR_DataSet/blob/master/dataset/det.py
  20. """
  21. def __init__(self, config):
  22. assert config.img_mode in ['RGB', 'BRG', 'GRAY']
  23. self.ignore_tags = config.ignore_tags
  24. # 加载字符级标注
  25. self.load_char_annotation = False
  26. self.data_list = self.load_data(config.file)
  27. item_keys = ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags']
  28. for item in item_keys:
  29. assert item in self.data_list[0], 'data_list from load_data must contains {}'.format(item_keys)
  30. self.img_mode = config.img_mode
  31. self.filter_keys = config.filter_keys
  32. self._init_pre_processes(config.pre_processes)
  33. self.transform = transforms.Compose([
  34. transforms.ToTensor(),
  35. transforms.Normalize(mean=config.mean, std=config.std)
  36. ])
  37. def _init_pre_processes(self, pre_processes):
  38. self.aug = []
  39. if pre_processes is not None:
  40. for aug in pre_processes:
  41. if 'args' not in aug:
  42. args = {}
  43. else:
  44. args = aug['args']
  45. if isinstance(args, dict):
  46. cls = eval(aug['type'])(**args)
  47. else:
  48. cls = eval(aug['type'])(args)
  49. self.aug.append(cls)
  50. def load_data(self, path: str) -> list:
  51. """
  52. 从json文件中读取出 文本行的坐标和gt,字符的坐标和gt
  53. :params path: 存储数据的文件夹或者文件
  54. return a dict ,包含了,'img_path','img_name','text_polys','texts','ignore_tags'
  55. """
  56. data_list = []
  57. content = load_json(path)
  58. for gt in tqdm(content['data_list'], desc='read file {}'.format(path)):
  59. try:
  60. img_path = os.path.join(content['data_root'], gt['img_name'])
  61. polygons = []
  62. texts = []
  63. illegibility_list = []
  64. language_list = []
  65. for annotation in gt['annotations']:
  66. if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
  67. continue
  68. polygons.append(annotation['polygon'])
  69. texts.append(annotation['text'])
  70. illegibility_list.append(annotation['illegibility'])
  71. language_list.append(annotation['language'])
  72. if self.load_char_annotation:
  73. for char_annotation in annotation['chars']:
  74. if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
  75. continue
  76. polygons.append(char_annotation['polygon'])
  77. texts.append(char_annotation['char'])
  78. illegibility_list.append(char_annotation['illegibility'])
  79. language_list.append(char_annotation['language'])
  80. data_list.append({'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': polygons,
  81. 'texts': texts, 'ignore_tags': illegibility_list})
  82. except:
  83. print(f'error gt:{img_path}')
  84. return data_list
  85. def apply_pre_processes(self, data):
  86. for aug in self.aug:
  87. data = aug(data)
  88. return data
  89. def __getitem__(self, index):
  90. data = copy.deepcopy(self.data_list[index])
  91. im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
  92. if self.img_mode == 'RGB':
  93. try:
  94. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  95. except:
  96. print(data['img_path'])
  97. data['img'] = im
  98. data['shape'] = [im.shape[0], im.shape[1]]
  99. data = self.apply_pre_processes(data)
  100. if self.transform:
  101. data['img'] = self.transform(data['img'])
  102. data['text_polys'] = data['text_polys']
  103. if len(self.filter_keys):
  104. data_dict = {}
  105. for k, v in data.items():
  106. if k not in self.filter_keys:
  107. data_dict[k] = v
  108. return data_dict
  109. else:
  110. return data
  111. def __len__(self):
  112. return len(self.data_list)
  113. if __name__ == '__main__':
  114. import torch
  115. from torch.utils.data import DataLoader
  116. # from config.cfg_det_db import config
  117. from local.cfg.cfg_det_db_latin import config
  118. from torchocr.utils import show_img, draw_bbox
  119. from matplotlib import pyplot as plt
  120. dataset = JsonDataset(config.dataset.train.dataset)
  121. train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
  122. for i, data in enumerate(tqdm(train_loader)):
  123. # print(data['img_path'])
  124. img = data['img'][0].numpy().transpose(1, 2, 0)
  125. shrink_label = data['shrink_map'].numpy().transpose(1, 2, 0)
  126. threshold_label = data['threshold_map'].numpy().transpose(1, 2, 0)
  127. show_img(img, title='img')
  128. show_img(shrink_label, title='shrink_label')
  129. show_img(threshold_label, title='threshold_label')
  130. plt.show()
  131. # print(threshold_label.shape, threshold_label.shape, img.shape)
  132. # show_img(img[0].numpy().transpose(1, 2, 0), title='img')
  133. # show_img((shrink_label[0].to(torch.float)).numpy(), title='shrink_label')
  134. # show_img((threshold_label[0].to(torch.float)).numpy(), title='threshold_label')
  135. # img = draw_bbox(img[0].numpy().transpose(1, 2, 0), np.array(data['text_polys']))
  136. # show_img(img, title='draw_bbox')
  137. # plt.show()
  138. pass