# CTCLoss_test.py

from torch import nn
import torch
import os
import pickle


class CTCLoss(nn.Module):
    """CTC loss with optional focal-style re-weighting of the per-sample losses."""

    def __init__(self, loss_cfg, reduction='mean', use_focal_loss=False):
        super().__init__()
        # Focal weighting needs the per-sample losses, so reduction is
        # disabled here and applied manually in forward().
        if use_focal_loss:
            reduction = 'none'
        self.loss_func = torch.nn.CTCLoss(
            blank=loss_cfg['blank_idx'], reduction=reduction, zero_infinity=True)
        self.use_focal_loss = use_focal_loss

    def forward(self, pred, args):
        # pred is a (features, logits) tuple; CTC only needs the logits.
        pred = pred[1]
        batch_size = pred.size(0)
        label, label_length = args['targets'], args['targets_lengths']
        pred = pred.log_softmax(2)
        # nn.CTCLoss expects inputs of shape (T, N, C).
        pred = pred.permute(1, 0, 2)
        preds_lengths = torch.tensor([pred.size(0)] * batch_size, dtype=torch.long)
        loss = self.loss_func(pred, label, preds_lengths, label_length)
        if self.use_focal_loss:
            # Focal weighting: (1 - exp(-loss))^2 down-weights easy samples.
            weight = torch.exp(-loss)
            weight = torch.subtract(torch.as_tensor(1.0), weight)
            weight = torch.square(weight)
            loss = torch.multiply(loss, weight)
        loss = loss.mean()
        return {'loss': loss}
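

# A minimal smoke test for CTCLoss. The blank_idx key and the
# (features, logits) / targets dict layout are taken from how forward()
# unpacks its inputs above; the concrete shapes and values are dummy
# assumptions, not part of any documented API.
def _ctc_loss_example():
    torch.manual_seed(0)
    batch, time_steps, num_classes = 2, 10, 6
    features = torch.randn(batch, time_steps, 96)
    logits = torch.randn(batch, time_steps, num_classes)
    targets = torch.randint(1, num_classes, (batch, 4), dtype=torch.long)
    target_lengths = torch.full((batch,), 4, dtype=torch.long)
    loss_fn = CTCLoss({'blank_idx': 0}, use_focal_loss=True)
    out = loss_fn((features, logits),
                  {'targets': targets, 'targets_lengths': target_lengths})
    print('ctc loss:', out['loss'].item())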


class EnhancedCTCLoss(nn.Module):
    """CTC loss optionally combined with a weighted center loss."""

    def __init__(self,
                 # use_focal_loss=False,
                 # use_ace_loss=False,
                 # ace_loss_weight=0.1,
                 loss_cfg,
                 use_center_loss=True,
                 center_loss_weight=0.05,
                 num_classes=5990,
                 feat_dim=96,
                 init_center=False,
                 center_file_path=None,
                 **kwargs):
        super(EnhancedCTCLoss, self).__init__()
        self.ctc_loss_func = CTCLoss(loss_cfg)

        # self.use_ace_loss = False
        # if use_ace_loss:
        #     self.use_ace_loss = use_ace_loss
        #     self.ace_loss_func = ACELoss()
        #     self.ace_loss_weight = ace_loss_weight

        self.use_center_loss = False
        if use_center_loss:
            self.use_center_loss = use_center_loss
            self.center_loss_func = CenterLoss(
                num_classes=num_classes,
                feat_dim=feat_dim,
                center_file_path=center_file_path)
            self.center_loss_weight = center_loss_weight

    def forward(self, predicts, args):
        loss = self.ctc_loss_func(predicts, args)["loss"]
        if self.use_center_loss:
            center_loss = self.center_loss_func(
                predicts)["loss_center"] * self.center_loss_weight
            loss = loss + center_loss
        # if self.use_ace_loss:
        #     ace_loss = self.ace_loss_func(
        #         predicts, batch)["loss_ace"] * self.ace_loss_weight
        #     loss = loss + ace_loss
        return {'loss': loss}
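

# Sketch of how EnhancedCTCLoss might be wired up, assuming the same
# (features, logits) prediction tuple as above. feat_dim must match the
# feature width and num_classes the size of the recognition alphabet;
# the concrete numbers here are dummy values.
def _enhanced_ctc_loss_example():
    torch.manual_seed(0)
    batch, time_steps, num_classes, feat_dim = 2, 10, 6, 96
    features = torch.randn(batch, time_steps, feat_dim)
    logits = torch.randn(batch, time_steps, num_classes)
    targets = torch.randint(1, num_classes, (batch, 4), dtype=torch.long)
    target_lengths = torch.full((batch,), 4, dtype=torch.long)
    loss_fn = EnhancedCTCLoss({'blank_idx': 0},
                              use_center_loss=True,
                              center_loss_weight=0.05,
                              num_classes=num_classes,
                              feat_dim=feat_dim)
    out = loss_fn((features, logits),
                  {'targets': targets, 'targets_lengths': target_lengths})
    print('enhanced ctc loss:', out['loss'].item())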


class CenterLoss(nn.Module):
    """
    Center loss: pulls each feature toward a learnable center for its class.
    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
        self.use_gpu = True if torch.cuda.is_available() else False
        # Optionally initialise the centers from a pickled {class_index: vector} dict.
        if center_file_path is not None:
            assert os.path.exists(
                center_file_path
            ), f"center path({center_file_path}) must exist when it is not None."
            with open(center_file_path, 'rb') as f:
                char_dict = pickle.load(f)
            with torch.no_grad():
                for key in char_dict.keys():
                    self.centers[key] = torch.as_tensor(char_dict[key])

    def forward(self, predicts):
        assert isinstance(predicts, (list, tuple))
        # predicts is (features, logits); the argmax of the logits serves as
        # the pseudo label for each time step.
        features, predicts = predicts
        label = predicts.argmax(dim=2)
        feats_reshape = torch.reshape(features, (-1, features.size(-1)))
        label = label.reshape(label.size(0) * label.size(1))
        batch_size = feats_reshape.size(0)

        # Squared L2 distance between every feature and every center:
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x . c
        square_feat = torch.pow(feats_reshape, 2).sum(
            dim=1, keepdim=True).expand(batch_size, self.num_classes)
        square_center = torch.pow(self.centers, 2).sum(
            dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat = square_feat + square_center
        distmat.addmm_(feats_reshape, self.centers.t(), beta=1, alpha=-2)

        # Mask that keeps, for each sample, only the distance to its own class center.
        classes = torch.arange(self.num_classes).long()
        if self.use_gpu and label.is_cuda:
            classes = classes.cuda()
        labels = label.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return {'loss_center': loss}
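

# Quick check that CenterLoss runs end to end and returns a scalar; the
# (features, logits) shapes mirror the examples above and are dummy values.
def _center_loss_example():
    torch.manual_seed(0)
    num_classes, feat_dim = 6, 96
    center_loss = CenterLoss(num_classes=num_classes, feat_dim=feat_dim)
    features = torch.randn(2, 10, feat_dim)
    logits = torch.randn(2, 10, num_classes)
    out = center_loss((features, logits))
    print('center loss:', out['loss_center'].item())


if __name__ == '__main__':
    _ctc_loss_example()
    _enhanced_ctc_loss_example()
    _center_loss_example()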