DistillationDetModel.py

import os
import copy

import torch
from torch import nn
from addict import Dict

from .DetModel import DetModel

__all__ = ['DistillationModel']


def load_pretrained_params(_model, _path):
    """Load a checkpoint into ``_model``; return the model unchanged if loading is skipped or fails."""
    if _path is None:
        return _model
    if not os.path.exists(_path):
        print(f'The pretrained_model {_path} does not exist')
        return _model
    params = torch.load(_path, map_location='cpu')
    state_dict = params['state_dict']
    # Strip the 'module.' prefix added by DataParallel/DistributedDataParallel.
    state_dict_no_module = {k.replace('module.', ''): v for k, v in state_dict.items()}
    _model.load_state_dict(state_dict_no_module)
    return _model


class DistillationModel(nn.Module):
    def __init__(self, config):
        """Build one DetModel per entry in ``config['models']`` (e.g. Teacher/Student)."""
        super().__init__()
        self.model_dict = nn.ModuleDict()
        self.model_name_list = []
        sub_model_cfgs = config['models']
        for key in sub_model_cfgs:
            sub_cfg = copy.deepcopy(sub_model_cfgs[key])
            sub_cfg.pop('type')
            freeze_params = False
            pretrained = None
            if 'freeze_params' in sub_cfg:
                freeze_params = sub_cfg.pop('freeze_params')
            if 'pretrained' in sub_cfg:
                pretrained = sub_cfg.pop('pretrained')
            model = DetModel(Dict(sub_cfg))
            if pretrained is not None:
                model = load_pretrained_params(model, pretrained)
            if freeze_params:
                # Freeze this sub-model (typically the teacher): no gradients, eval mode.
                for param in model.parameters():
                    param.requires_grad = False
                model.eval()
            self.model_dict[key] = model
            self.model_name_list.append(key)

    def forward(self, x):
        # Run every sub-model on the same input and collect their outputs by name.
        result_dict = dict()
        for model_name in self.model_name_list:
            result_dict[model_name] = self.model_dict[model_name](x)
        return result_dict
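

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): builds a
    # Teacher/Student pair from a plain dict config.  The sub-model keys
    # ('backbone', 'neck', 'head') and the type names below are illustrative
    # assumptions; they must match whatever this repo's DetModel actually accepts.
    demo_cfg = {
        'models': {
            'Teacher': {
                'type': 'DetModel',
                'pretrained': None,      # or a path to a .pth checkpoint
                'freeze_params': True,   # teacher is frozen during distillation
                'backbone': {'type': 'ResNet', 'layers': 18},
                'neck': {'type': 'DB_fpn', 'out_channels': 256},
                'head': {'type': 'DBHead'},
            },
            'Student': {
                'type': 'DetModel',
                'pretrained': None,
                'freeze_params': False,
                'backbone': {'type': 'MobileNetV3', 'model_name': 'large'},
                'neck': {'type': 'DB_fpn', 'out_channels': 96},
                'head': {'type': 'DBHead'},
            },
        }
    }
    model = DistillationModel(demo_cfg)
    x = torch.randn(1, 3, 640, 640)
    outputs = model(x)
    # outputs is a dict keyed by sub-model name, e.g. {'Teacher': ..., 'Student': ...}
    print(list(outputs.keys()))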