123456789101112131415161718192021222324252627282930313233 |
- # -*- coding: utf-8 -*-
- # @Time : 2020/6/22 14:16
- # @Author : zhoujun
- import PIL
- import numpy as np
- import torch
- from torchvision import transforms
# Names exported by ``from <this module> import *``: only the collate class.
__all__ = ["DetCollectFN"]
class DetCollectFN:
    """Collate a list of sample dicts into a single batch dict.

    For every key found in any sample, the output holds that key's
    per-sample values in batch order.  NumPy arrays and PIL images are
    converted to tensors.  A key whose values are all tensors of the same
    shape is stacked along a new leading (batch) dimension.  Otherwise the
    key keeps a plain list instead of letting ``torch.stack`` fail.  This
    covers, e.g., per-image polygon arrays with a varying number of boxes,
    or a key that is array-like in only some samples.
    """

    def __init__(self, *args, **kwargs):
        # Any arguments are accepted and ignored; presumably so the object
        # can be built from a generic config that passes extra parameters
        # -- TODO confirm against the caller.
        pass

    def __call__(self, batch):
        """Collate ``batch`` into one dict.

        Args:
            batch: iterable of dicts, one per sample.

        Returns:
            dict mapping each key (in first-seen order) to either a stacked
            tensor, when all its values are same-shape tensors, or the list
            of per-sample values.  An empty batch yields an empty dict.
        """
        data_dict = {}
        tensor_keys = set()  # keys with at least one tensor(-convertible) value
        for sample in batch:
            for key, value in sample.items():
                if isinstance(value, np.ndarray):
                    if any(stride < 0 for stride in value.strides):
                        # PyTorch's NumPy conversion may reject negative
                        # strides (e.g. from ``a[..., ::-1]``); ndarray.copy()
                        # returns a C-ordered array.
                        value = value.copy()
                    value = torch.tensor(value)
                    tensor_keys.add(key)
                elif isinstance(value, torch.Tensor):
                    tensor_keys.add(key)
                elif isinstance(value, PIL.Image.Image):
                    value = transforms.ToTensor()(value)
                    tensor_keys.add(key)
                data_dict.setdefault(key, []).append(value)
        for key in tensor_keys:
            if self._stackable(data_dict[key]):
                data_dict[key] = torch.stack(data_dict[key], 0)
        return data_dict

    @staticmethod
    def _stackable(values):
        """Return True if every item of ``values`` is a tensor of one shape."""
        if not all(isinstance(item, torch.Tensor) for item in values):
            return False
        first_shape = values[0].shape
        return all(item.shape == first_shape for item in values)
|