CTCLoss.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from torch import nn
  2. import torch
  3. class CTCLoss(nn.Module):
  4. def __init__(self, loss_cfg, reduction='mean'):
  5. super().__init__()
  6. self.loss_func = torch.nn.CTCLoss(blank=loss_cfg['blank_idx'], reduction=reduction, zero_infinity=True)
  7. def forward(self, pred, args):
  8. # print(pred)
  9. batch_size = pred.size(0)
  10. label, label_length = args['targets'], args['targets_lengths']
  11. pred = pred.log_softmax(2)
  12. pred = pred.permute(1, 0, 2)
  13. preds_lengths = torch.tensor([pred.size(0)] * batch_size, dtype=torch.long)
  14. loss = self.loss_func(pred, label, preds_lengths, label_length)
  15. return {'loss': loss}
  16. class EnhancedCTCLoss(nn.Module):
  17. def __init__(self,
  18. # use_focal_loss=False,
  19. # use_ace_loss=False,
  20. # ace_loss_weight=0.1,
  21. loss_cfg,
  22. use_center_loss=True,
  23. center_loss_weight=0.05,
  24. num_classes=6625,
  25. feat_dim=96,
  26. init_center=False,
  27. center_file_path=None,
  28. **kwargs):
  29. super(EnhancedCTCLoss, self).__init__()
  30. self.ctc_loss_func = CTCLoss(loss_cfg)
  31. # self.use_ace_loss = False
  32. # if use_ace_loss:
  33. # self.use_ace_loss = use_ace_loss
  34. # self.ace_loss_func = ACELoss()
  35. # self.ace_loss_weight = ace_loss_weight
  36. self.use_center_loss = False
  37. if use_center_loss:
  38. self.use_center_loss = use_center_loss
  39. self.center_loss_func = CenterLoss(
  40. num_classes=num_classes,
  41. feat_dim=feat_dim,
  42. init_center=init_center,
  43. center_file_path=center_file_path)
  44. self.center_loss_weight = center_loss_weight
  45. def forward(self, predicts, batch):
  46. loss = self.ctc_loss_func(predicts, batch)["loss"]
  47. if self.use_center_loss:
  48. center_loss = self.center_loss_func(
  49. predicts, batch)["loss_center"] * self.center_loss_weight
  50. loss = loss + center_loss
  51. # if self.use_ace_loss:
  52. # ace_loss = self.ace_loss_func(
  53. # predicts, batch)["loss_ace"] * self.ace_loss_weight
  54. # loss = loss + ace_loss
  55. return {'enhanced_ctc_loss': loss}
  56. class CenterLoss(nn.Module):
  57. """
  58. Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
  59. """
  60. def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
  61. super().__init__()
  62. self.num_classes = num_classes
  63. self.feat_dim = feat_dim
  64. self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
  65. # if center_file_path is not None:
  66. # assert os.path.exists(
  67. # center_file_path
  68. # ), f"center path({center_file_path}) must exist when it is not None."
  69. # with open(center_file_path, 'rb') as f:
  70. # char_dict = pickle.load(f)
  71. # for key in char_dict.keys():
  72. # self.centers[key] = paddle.to_tensor(char_dict[key])
  73. def forward(self, predicts, batch):
  74. # assert isinstance(predicts, (list, tuple))
  75. # features, predicts = predicts
  76. predicts = predicts
  77. features = batch
  78. # feats_reshape = paddle.reshape(
  79. # features, [-1, features.shape[-1]]).astype("float64")
  80. # label = paddle.argmax(predicts, axis=2)
  81. # label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
  82. #
  83. # batch_size = feats_reshape.shape[0]
  84. #
  85. # #calc l2 distance between feats and centers
  86. # square_feat = paddle.sum(paddle.square(feats_reshape),
  87. # axis=1,
  88. # keepdim=True)
  89. # square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
  90. #
  91. # square_center = paddle.sum(paddle.square(self.centers),
  92. # axis=1,
  93. # keepdim=True)
  94. # square_center = paddle.expand(
  95. # square_center, [self.num_classes, batch_size]).astype("float64")
  96. # square_center = paddle.transpose(square_center, [1, 0])
  97. #
  98. # distmat = paddle.add(square_feat, square_center)
  99. # feat_dot_center = paddle.matmul(feats_reshape,
  100. # paddle.transpose(self.centers, [1, 0]))
  101. # distmat = distmat - 2.0 * feat_dot_center
  102. x = predicts
  103. labels = features
  104. batch_size = x.size(0)
  105. distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
  106. torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
  107. distmat.addmm_(1, -2, x, self.centers.t())
  108. #generate the mask
  109. # classes = torch.arange(self.num_classes).astype("int64")
  110. classes = torch.arange(self.num_classes).long()
  111. # label = paddle.expand(
  112. # paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
  113. labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
  114. # mask = paddle.equal(
  115. # paddle.expand(classes, [batch_size, self.num_classes]),
  116. # label).astype("float64")
  117. mask = labels.eq(classes.expand(batch_size, self.num_classes))
  118. # dist = paddle.multiply(distmat, mask)
  119. dist = distmat * mask.float()
  120. # loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
  121. loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
  122. return {'loss_center': loss}