__init__.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import copy
  2. from addict import Dict
  3. from torch.utils.data import DataLoader
  4. from .RecDataSet import RecDataLoader, RecTextLineDataset, RecLmdbDataset
  5. from .DetDataSet import JsonDataset
  6. from .RecCollateFn import RecCollateFn
  7. from .DetCollateFN import DetCollectFN
  8. from .DetDateSetPse import MyDataset
  9. from .DetDataSetFce import FCEDataset
  10. __all__ = ['build_dataloader']
  11. support_dataset = ['RecTextLineDataset', 'RecLmdbDataset', 'DetTextLineDataset','JsonDataset','MyDataset','FCEDataset']
  12. support_loader = ['RecDataLoader', 'DataLoader']
  13. def build_dataset(config):
  14. """
  15. 根据配置构造dataset
  16. :param config: 数据集相关的配置,一般为 config['dataset']['train']['dataset] or config['dataset']['eval']['dataset]
  17. :return: 根据配置构造好的 DataSet 类对象
  18. """
  19. dataset_type = config.pop('type')
  20. assert dataset_type in support_dataset, f'{dataset_type} is not developed yet!, only {support_dataset} are support now'
  21. dataset_class = eval(dataset_type)(config)
  22. return dataset_class
  23. def build_loader(dataset, config):
  24. """
  25. 根据配置构造 dataloader, 包含两个步骤,1. 构造 collate_fn, 2. 构造 dataloader
  26. :param dataset: 继承自 torch.utils.data.DataSet的类对象
  27. :param config: loader 相关的配置,一般为 config['dataset']['train']['loader] or config['dataset']['eval']['loader]
  28. :return: 根据配置构造好的 DataSet 类对象
  29. """
  30. dataloader_type = config.pop('type')
  31. assert dataloader_type in support_loader, f'{dataloader_type} is not developed yet!, only {support_loader} are support now'
  32. # build collate_fn
  33. if 'collate_fn' in config:
  34. config['collate_fn']['dataset'] = dataset
  35. collate_fn = build_collate_fn(config.pop('collate_fn'))
  36. else:
  37. collate_fn = None
  38. dataloader_class = eval(dataloader_type)(dataset=dataset, collate_fn=collate_fn, **config ,pin_memory=True)
  39. return dataloader_class
  40. def build_loader_add(dataset, config):
  41. """
  42. 根据配置构造 dataloader, 包含两个步骤,1. 构造 collate_fn, 2. 构造 dataloader
  43. :param dataset: 继承自 torch.utils.data.DataSet的类对象
  44. :param config: loader 相关的配置,一般为 config['dataset']['train']['loader] or config['dataset']['eval']['loader]
  45. :return: 根据配置构造好的 DataSet 类对象
  46. """
  47. dataloader_type = config.pop('type')
  48. assert dataloader_type in support_loader, f'{dataloader_type} is not developed yet!, only {support_loader} are support now'
  49. # build collate_fn
  50. if 'collate_fn' in config:
  51. config['collate_fn']['dataset'] = dataset
  52. collate_fn = build_collate_fn(config.pop('collate_fn'))
  53. else:
  54. collate_fn = None
  55. dataloader_class = eval(dataloader_type)(dataset=dataset, collate_fn=collate_fn, **config ,pin_memory=True)
  56. return dataloader_class
  57. def build_collate_fn(config):
  58. """
  59. 根据配置构造 collate_fn
  60. :param config: collate_fn 相关的配置
  61. :return: 根据配置构造好的 collate_fn 类对象
  62. """
  63. collate_fn_type = config.pop('type')
  64. if len(collate_fn_type) == 0:
  65. return None
  66. collate_fn_class = eval(collate_fn_type)(**config)
  67. return collate_fn_class
  68. def build_dataloader(config):
  69. """
  70. 根据配置构造 dataloader, 包含两个步骤,1. 构造 dataset, 2. 构造 dataloader
  71. :param config: 数据集相关的配置,一般为 config['dataset']['train'] or config['dataset']['eval']
  72. :return: 根据配置构造好的 DataLoader 类对象
  73. """
  74. # build dataset
  75. copy_config = copy.deepcopy(config)
  76. copy_config = Dict(copy_config)
  77. dataset = build_dataset(copy_config.dataset)
  78. # build loader
  79. loader = build_loader(dataset, copy_config.loader)
  80. return loader