DBLoss.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/8/23 21:56
  3. # @Author : zhoujun
  4. from torch import nn
  5. from torchocr.networks.losses.DetBasicLoss import BalanceCrossEntropyLoss, MaskL1Loss, DiceLoss, BalanceLoss
  6. class DBLoss(nn.Module):
  7. def __init__(self, balance_loss=True, main_loss_type='DiceLoss', alpha=1.0, beta=10, ohem_ratio=3, reduction='mean',
  8. eps=1e-6):
  9. """
  10. Implement PSE Loss.
  11. :param alpha: binary_map loss 前面的系数
  12. :param beta: threshold_map loss 前面的系数
  13. :param ohem_ratio: OHEM的比例
  14. :param reduction: 'mean' or 'sum'对 batch里的loss 算均值或求和
  15. """
  16. super().__init__()
  17. assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
  18. self.alpha = alpha
  19. self.beta = beta
  20. # self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
  21. self.bce_loss = BalanceLoss(
  22. balance_loss=balance_loss,
  23. main_loss_type=main_loss_type,
  24. negative_ratio=ohem_ratio)
  25. self.dice_loss = DiceLoss(eps=eps)
  26. self.l1_loss = MaskL1Loss(eps=eps)
  27. self.reduction = reduction
  28. def forward(self, pred, batch):
  29. """
  30. :param pred:
  31. :param batch: bach为一个dict{
  32. 'shrink_map': 收缩图,b*c*h,w
  33. 'shrink_mask: 收缩图mask,b*c*h,w
  34. 'threshold_map: 二值化边界gt,b*c*h,w
  35. 'threshold_mask: 二值化边界gtmask,b*c*h,w
  36. }
  37. :return:
  38. """
  39. shrink_maps = pred[:, 0, :, :]
  40. threshold_maps = pred[:, 1, :, :]
  41. binary_maps = pred[:, 2, :, :]
  42. loss_shrink_maps = self.alpha * self.bce_loss(shrink_maps, batch['shrink_map'], batch['shrink_mask'])
  43. loss_threshold_maps = self.beta * self.l1_loss(threshold_maps, batch['threshold_map'], batch['threshold_mask'])
  44. loss_dict = dict(loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps)
  45. if pred.size()[1] > 2:
  46. loss_binary_maps = self.dice_loss(binary_maps, batch['shrink_map'], batch['shrink_mask'])
  47. loss_dict['loss_binary_maps'] = loss_binary_maps
  48. loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
  49. loss_dict['loss'] = loss_all
  50. else:
  51. loss_dict['loss'] = loss_shrink_maps
  52. return loss_dict