# CombinedLoss.py (1023 bytes)

# (line numbers 1-27 from the source viewer)
  1. import torch
  2. import torch.nn as nn
  3. from .distillation_loss import DistillationDilaDBLoss,DistillationDBLoss,DistillationDMLLoss
  class CombinedLoss(nn.Module):
      """Weighted combination of several loss functions built from a config.

      Each entry of ``_cfg_list['combine_list']`` maps a loss name to its
      constructor keyword arguments plus a mandatory ``'weight'``. The name is
      resolved with ``eval`` in this module's namespace, so it must refer to
      something visible here (e.g. one of the imported distillation losses).
      """

      def __init__(self, _cfg_list=None):
          """Instantiate every configured sub-loss and record its weight.

          Args:
              _cfg_list: mapping with a ``'combine_list'`` entry of the form
                  ``{loss_name: {'weight': w, **loss_kwargs}}``. The caller's
                  config is left unmodified.

          Raises:
              TypeError: if ``_cfg_list`` is None.
              KeyError: if ``'combine_list'`` or a loss's ``'weight'`` is missing.
          """
          super().__init__()
          if _cfg_list is None:
              raise TypeError("CombinedLoss requires a config with a 'combine_list' entry")
          # NOTE(review): plain lists, so sub-losses are not registered as
          # submodules (.to()/.train()/state_dict do not reach them). Kept as-is
          # because eval'd entries are not guaranteed to be nn.Module instances.
          self.loss_func = []
          self.loss_weight = []
          for key, val in _cfg_list['combine_list'].items():
              # Copy before popping so the caller's config stays intact and can
              # be reused to build another CombinedLoss.
              params = dict(val)
              if 'weight' not in params:
                  raise KeyError(f"loss '{key}' config is missing 'weight'")
              self.loss_weight.append(params.pop('weight'))
              # SECURITY: eval runs arbitrary expressions from the config;
              # only use with trusted config files.
              self.loss_func.append(eval(key)(**params))

      def forward(self, input, batch, **kwargs):
          """Evaluate every sub-loss and combine the weighted results.

          Each sub-loss is called as ``loss_func(input, batch, **kwargs)`` and
          may return a dict of tensors or a single tensor. A bare tensor is
          wrapped as ``{'loss_<idx>': tensor}``. Every entry is multiplied by
          that sub-loss's weight and copied into the result. Entries from later
          sub-losses overwrite earlier ones with the same key. A sub-loss
          contributes its ``'loss'`` entry to the total if it has one,
          otherwise the sum of all its entries.

          Returns:
              dict of weighted sub-loss terms plus ``'loss'``, the total
              (``0.`` when no sub-losses are configured).
          """
          loss_dict = {}
          loss_all = 0.
          for idx, (loss_func, weight) in enumerate(zip(self.loss_func, self.loss_weight)):
              loss = loss_func(input, batch, **kwargs)
              if isinstance(loss, torch.Tensor):
                  loss = {f'loss_{idx}': loss}
              loss = {key: value * weight for key, value in loss.items()}
              if 'loss' in loss:
                  loss_all = loss_all + loss['loss']
              else:
                  # torch.add takes exactly two operands, not a list; sum()
                  # folds any number of tensors onto the running total.
                  loss_all = loss_all + sum(loss.values())
              loss_dict.update(loss)
          # Replaces any per-sub-loss 'loss' entry copied in above.
          loss_dict['loss'] = loss_all
          return loss_dict