DetDataSetFce.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import os
  2. import cv2
  3. import json
  4. import copy
  5. import numpy as np
  6. from tqdm import tqdm
  7. from torch.utils.data import Dataset
  8. from torchvision import transforms
  9. from torchocr.datasets.det_modules import *
  10. def load_json(file_path: str):
  11. with open(file_path, 'r', encoding='utf8') as f:
  12. content = json.load(f)
  13. return content
  14. class FCEDataset(Dataset):
  15. def __init__(self, config):
  16. assert config.img_mode in ['RGB', 'BRG', 'GRAY']
  17. self.ignore_tags = config.ignore_tags
  18. # 加载字符级标注
  19. self.load_char_annotation = False
  20. self.data_list = self.load_data(config.file)
  21. item_keys = ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags']
  22. for item in item_keys:
  23. assert item in self.data_list[0], 'data_list from load_data must contains {}'.format(item_keys)
  24. self.img_mode = config.img_mode
  25. self.filter_keys = config.filter_keys
  26. self._init_pre_processes(config.pre_processes)
  27. self.transform = transforms.Compose([
  28. transforms.ToTensor(),
  29. transforms.Normalize(mean=config.mean, std=config.std)
  30. ])
  31. def _init_pre_processes(self, pre_processes):
  32. self.aug = []
  33. if pre_processes is not None:
  34. for aug in pre_processes:
  35. if 'args' not in aug:
  36. args = {}
  37. else:
  38. args = aug['args']
  39. if isinstance(args, dict):
  40. cls = eval(aug['type'])(**args)
  41. else:
  42. cls = eval(aug['type'])(args)
  43. self.aug.append(cls)
  44. def load_data(self, path: str) -> list:
  45. """
  46. 从json文件中读取出 文本行的坐标和gt,字符的坐标和gt
  47. :params path: 存储数据的文件夹或者文件
  48. return a dict ,包含了,'img_path','img_name','text_polys','texts','ignore_tags'
  49. """
  50. data_list = []
  51. content = load_json(path)
  52. for gt in tqdm(content['data_list'], desc='read file {}'.format(path)):
  53. try:
  54. img_path = os.path.join(content['data_root'], gt['img_name'])
  55. polygons = []
  56. texts = []
  57. illegibility_list = []
  58. language_list = []
  59. max_poly_len = 0
  60. if len( gt['annotations'])==0:
  61. print(img_path)
  62. continue
  63. for annotation in gt['annotations']:
  64. if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
  65. continue
  66. max_poly_len = max(max_poly_len, len(annotation['polygon']))
  67. polygons.append(annotation['polygon'])
  68. texts.append(annotation['text'])
  69. illegibility_list.append(annotation['illegibility'])
  70. language_list.append(annotation['language'])
  71. if self.load_char_annotation:
  72. for char_annotation in annotation['chars']:
  73. if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
  74. continue
  75. polygons.append(char_annotation['polygon'])
  76. texts.append(char_annotation['char'])
  77. illegibility_list.append(char_annotation['illegibility'])
  78. language_list.append(char_annotation['language'])
  79. ex_polygons = []
  80. for pl in polygons:
  81. ex_pl = pl + [pl[-1]] * (max_poly_len - len(pl))
  82. ex_polygons.append(ex_pl)
  83. data_list.append(
  84. {'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': np.array(ex_polygons, dtype=np.float32),
  85. 'texts': texts, 'ignore_tags': illegibility_list})
  86. except:
  87. print(f'error gt:{img_path}')
  88. return data_list
  89. def apply_pre_processes(self, data):
  90. for aug in self.aug:
  91. data = aug(data)
  92. return data
  93. def __getitem__(self, index):
  94. # try:
  95. data = copy.deepcopy(self.data_list[index])
  96. im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
  97. if self.img_mode == 'RGB':
  98. try:
  99. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  100. except:
  101. print(data['img_path'])
  102. data['img'] = im
  103. data['shape'] = [im.shape[0], im.shape[1]]
  104. data = self.apply_pre_processes(data)
  105. if self.transform:
  106. data['img'] = self.transform(data['img'])
  107. data['text_polys'] = data['text_polys']
  108. if len(self.filter_keys):
  109. data_dict = {}
  110. for k, v in data.items():
  111. if k not in self.filter_keys:
  112. data_dict[k] = v
  113. return data_dict
  114. else:
  115. return data
  116. def __len__(self):
  117. return len(self.data_list)
  118. if __name__ == '__main__':
  119. import torch
  120. from torch.utils.data import DataLoader
  121. # from config.cfg_det_db import config
  122. from local.cfg.cfg_det_fce import config
  123. from torchocr.utils import show_img, draw_bbox
  124. from matplotlib import pyplot as plt
  125. dataset = JsonDataset(config.dataset.train.dataset)
  126. train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
  127. for i, data in enumerate(tqdm(train_loader)):
  128. # print(data['img_path'])
  129. # img = data['img'][0].numpy().transpose(1, 2, 0)
  130. # shrink_label = data['shrink_map'].numpy().transpose(1, 2, 0)
  131. # threshold_label = data['threshold_map'].numpy().transpose(1, 2, 0)
  132. # show_img(img, title='img')
  133. # show_img(shrink_label, title='shrink_label')
  134. # show_img(threshold_label, title='threshold_label')
  135. # plt.show()
  136. # print(threshold_label.shape, threshold_label.shape, img.shape)
  137. # show_img(img[0].numpy().transpose(1, 2, 0), title='img')
  138. # show_img((shrink_label[0].to(torch.float)).numpy(), title='shrink_label')
  139. # show_img((threshold_label[0].to(torch.float)).numpy(), title='threshold_label')
  140. # img = draw_bbox(img[0].numpy().transpose(1, 2, 0), np.array(data['text_polys']))
  141. # show_img(img, title='draw_bbox')
  142. # plt.show()
  143. pass