DetCollateFN.py 1016 B

123456789101112131415161718192021222324252627282930313233
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/22 14:16
  3. # @Author : zhoujun
  4. import PIL
  5. import numpy as np
  6. import torch
  7. from torchvision import transforms
  8. __all__ = ['DetCollectFN']
  9. class DetCollectFN:
  10. def __init__(self, *args, **kwargs):
  11. pass
  12. def __call__(self, batch):
  13. data_dict = {}
  14. to_tensor_keys = []
  15. for sample in batch:
  16. for k, v in sample.items():
  17. if k not in data_dict:
  18. data_dict[k] = []
  19. if isinstance(v, (np.ndarray, torch.Tensor, PIL.Image.Image)):
  20. if k not in to_tensor_keys:
  21. to_tensor_keys.append(k)
  22. if isinstance(v, np.ndarray):
  23. v = torch.tensor(v)
  24. if isinstance(v, PIL.Image.Image):
  25. v = transforms.ToTensor()(v)
  26. data_dict[k].append(v)
  27. for k in to_tensor_keys:
  28. data_dict[k] = torch.stack(data_dict[k], 0)
  29. return data_dict