123456789101112131415161718192021222324252627 |
- import torch
- import torch.nn as nn
- from .distillation_loss import DistillationDilaDBLoss,DistillationDBLoss,DistillationDMLLoss
class CombinedLoss(nn.Module):
    """Weighted combination of several loss functions built from a config.

    Each entry of ``_cfg_list['combine_list']`` maps a loss class name to its
    keyword arguments. The ``'weight'`` entry is taken out of those arguments
    and used to scale every value that loss returns. The remaining entries are
    passed to the loss constructor.

    Args:
        _cfg_list: Mapping with a ``'combine_list'`` key whose value is a dict
            ``{loss_name: {'weight': number, **ctor_kwargs}}``. ``loss_name`` is
            resolved with ``eval`` in this module's namespace, so it must name
            something visible here (e.g. an imported loss class, or a dotted
            name such as ``nn.L1Loss``). The caller's config is not modified.

    Raises:
        TypeError: If ``_cfg_list`` is None.
        KeyError: If a loss entry has no ``'weight'``.
    """

    def __init__(self, _cfg_list=None):
        super().__init__()
        if _cfg_list is None:
            raise TypeError(
                "CombinedLoss requires a config with a 'combine_list' entry")
        # NOTE(review): plain lists, so loss sub-modules are not registered
        # with this Module (not moved by .to(), not switched by .train()/.eval()).
        # Kept as lists because the entries are not guaranteed to be nn.Modules.
        self.loss_func = []
        self.loss_weight = []
        for key, val in _cfg_list['combine_list'].items():
            # Work on a copy: popping 'weight' from the caller's dict would make
            # a second construction from the same config fail.
            params = dict(val)
            if 'weight' not in params:
                raise KeyError(
                    f"loss config for {key!r} must contain 'weight'")
            self.loss_weight.append(params.pop('weight'))
            # SECURITY: eval() executes the config key as Python code; only
            # load configs from trusted sources.
            self.loss_func.append(eval(key)(**params))

    def forward(self, input, batch, **kwargs):
        """Evaluate every configured loss and combine the weighted results.

        Args:
            input: Model output, forwarded unchanged to each loss.
            batch: Batch/target data, forwarded unchanged to each loss.
            **kwargs: Extra keyword arguments forwarded to each loss.

        Returns:
            dict: Every weighted entry returned by the individual losses, plus
            ``'loss'`` holding the total. If a loss's result contains a
            ``'loss'`` key, only that entry contributes to the total (its other
            keys are kept for reporting). Otherwise all of its values are summed.
            A loss that returns a bare tensor is stored under
            ``'loss_<ClassName>_<index>'``. With no configured losses the total
            is the float ``0.``.
        """
        loss_dict = {}
        loss_all = 0.
        for idx, (loss_func, weight) in enumerate(
                zip(self.loss_func, self.loss_weight)):
            loss = loss_func(input, batch, **kwargs)
            if isinstance(loss, torch.Tensor):
                # Normalise single-tensor losses to the dict form used below.
                loss = {f"loss_{type(loss_func).__name__}_{idx}": loss}
            loss = {name: value * weight for name, value in loss.items()}
            if 'loss' in loss:
                loss_all = loss_all + loss['loss']
            else:
                # torch.add() takes exactly two operands, not a list; sum()
                # folds all values. Out-of-place '+' avoids aliasing tensors
                # that are also stored in loss_dict.
                loss_all = loss_all + sum(loss.values())
            loss_dict.update(loss)
        loss_dict['loss'] = loss_all
        return loss_dict
|