import os
import copy

import torch
from torch import nn
from addict import Dict

from .DetModel import DetModel

__all__ = ['DistillationModel']
def load_pretrained_params(_model, _path):
    """Load a checkpoint into ``_model``; return the model unchanged if the path is missing."""
    if _path is None:
        return _model
    if not os.path.exists(_path):
        print(f'The pretrained_model {_path} does not exist')
        return _model
    params = torch.load(_path, map_location='cpu')
    state_dict = params['state_dict']
    # Strip the 'module.' prefix left by DataParallel/DistributedDataParallel checkpoints.
    state_dict_no_module = {k.replace('module.', ''): v for k, v in state_dict.items()}
    _model.load_state_dict(state_dict_no_module)
    return _model
class DistillationModel(nn.Module):
    def __init__(self, config):
        super(DistillationModel, self).__init__()
        self.model_dict = nn.ModuleDict()
        self.model_name_list = []
        sub_model_cfgs = config['models']
        for key in sub_model_cfgs:
            sub_cfg = copy.deepcopy(sub_model_cfgs[key])
            sub_cfg.pop('type')
            freeze_params = False
            pretrained = None
            if 'freeze_params' in sub_cfg:
                freeze_params = sub_cfg.pop('freeze_params')
            if 'pretrained' in sub_cfg:
                pretrained = sub_cfg.pop('pretrained')
            model = DetModel(Dict(sub_cfg))
            if pretrained is not None:
                model = load_pretrained_params(model, pretrained)
            if freeze_params:
                # Freeze this sub-model (typically the teacher): no gradients, eval mode.
                for param in model.parameters():
                    param.requires_grad = False
                model.eval()
            self.model_dict[key] = model
            self.model_name_list.append(key)
    def forward(self, x):
        # Run every sub-model on the same input and collect outputs by name.
        result_dict = dict()
        for model_name in self.model_name_list:
            result_dict[model_name] = self.model_dict[model_name](x)
        return result_dict
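

# --- Minimal usage sketch (hypothetical, for illustration only) ---
# The sub-model names ('Teacher'/'Student') and the checkpoint path below are
# made up, and the backbone/neck/head sub-configs must follow whatever schema
# DetModel in this repository actually expects.
#
# config = Dict({
#     'models': {
#         'Teacher': {
#             'type': 'DetModel',
#             'freeze_params': True,             # teacher weights stay fixed
#             'pretrained': 'path/to/teacher.pth',
#             # 'backbone': {...}, 'neck': {...}, 'head': {...}
#         },
#         'Student': {
#             'type': 'DetModel',
#             'freeze_params': False,            # student is trained
#             # 'backbone': {...}, 'neck': {...}, 'head': {...}
#         },
#     },
# })
# model = DistillationModel(config)
# outputs = model(images)   # -> {'Teacher': teacher_out, 'Student': student_out}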