123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- from torch import nn
- import torch
class CTCLoss(nn.Module):
    """Connectionist Temporal Classification loss for batch-first scores.

    Wraps ``torch.nn.CTCLoss``: raw scores are converted to
    log-probabilities, moved to the time-major layout that
    ``torch.nn.CTCLoss`` requires, and every sequence is treated as using
    the full time dimension.

    Args:
        loss_cfg: mapping providing ``'blank_idx'``, the blank label index.
        reduction: reduction mode forwarded to ``torch.nn.CTCLoss``.
    """

    def __init__(self, loss_cfg, reduction='mean'):
        super().__init__()
        blank_index = loss_cfg['blank_idx']
        # zero_infinity=True replaces infinite losses (infeasible alignments)
        # and their gradients with zero instead of propagating inf.
        self.loss_func = torch.nn.CTCLoss(
            blank=blank_index,
            reduction=reduction,
            zero_infinity=True,
        )

    def forward(self, pred, args):
        """Compute the CTC loss.

        Args:
            pred: 3-D score tensor indexed ``(batch, time, classes)``; the
                log-softmax is taken over the last dimension.
            args: mapping with ``'targets'`` and ``'targets_lengths'``, passed
                to ``torch.nn.CTCLoss`` as targets and target lengths.

        Returns:
            dict with the single key ``'loss'``.
        """
        targets = args['targets']
        target_lengths = args['targets_lengths']
        num_seqs, num_steps = pred.size(0), pred.size(1)
        # torch.nn.CTCLoss expects time-major (T, N, C) log-probabilities.
        log_probs = pred.log_softmax(2).permute(1, 0, 2)
        # Every sequence is assumed to span all time steps (no padding info).
        input_lengths = torch.full((num_seqs,), num_steps, dtype=torch.long)
        return {'loss': self.loss_func(log_probs, targets, input_lengths, target_lengths)}
class EnhancedCTCLoss(nn.Module):
    """CTC loss, optionally combined with a weighted center loss.

    Args:
        loss_cfg: config mapping forwarded to ``CTCLoss`` (must provide
            ``'blank_idx'``).
        use_center_loss: if true, add ``center_loss_weight * CenterLoss`` to
            the CTC loss. This requires ``predicts`` to be a
            ``(features, logits)`` pair in ``forward``.
        center_loss_weight: weight of the center-loss term.
        num_classes: number of classes for the center loss.
        feat_dim: feature dimension for the center loss.
        init_center: accepted for config compatibility; currently unused and
            not forwarded to ``CenterLoss``.
        center_file_path: forwarded to ``CenterLoss``.
        **kwargs: ignored; accepted so extra config keys do not break
            construction.
    """

    def __init__(self,
                 loss_cfg,
                 use_center_loss=True,
                 center_loss_weight=0.05,
                 num_classes=6625,
                 feat_dim=96,
                 init_center=False,
                 center_file_path=None,
                 **kwargs):
        super().__init__()
        self.ctc_loss_func = CTCLoss(loss_cfg)
        self.use_center_loss = bool(use_center_loss)
        self.center_loss_weight = center_loss_weight
        if self.use_center_loss:
            # init_center is intentionally NOT passed: CenterLoss.__init__ did
            # not accept it, which made construction fail with a TypeError.
            self.center_loss_func = CenterLoss(
                num_classes=num_classes,
                feat_dim=feat_dim,
                center_file_path=center_file_path)

    def forward(self, predicts, batch):
        """Compute the (optionally center-regularised) CTC loss.

        Args:
            predicts: either a score tensor for ``CTCLoss``, or a
                ``(features, logits)`` pair. The pair form is required when
                the center loss is enabled. ``features`` is flattened to
                ``(-1, features.size(-1))``. Its rows must align one-to-one
                with the flattened ``(batch, time)`` positions of ``logits``
                (assumed layout, TODO confirm against the model head).
            batch: mapping forwarded to ``CTCLoss`` (``'targets'``,
                ``'targets_lengths'``).

        Returns:
            dict with the single key ``'enhanced_ctc_loss'``.

        Raises:
            ValueError: center loss is enabled but ``predicts`` is not a
                ``(features, logits)`` pair.
        """
        if isinstance(predicts, (list, tuple)):
            features, logits = predicts
        else:
            features, logits = None, predicts

        loss = self.ctc_loss_func(logits, batch)["loss"]

        if self.use_center_loss:
            if features is None:
                raise ValueError(
                    "EnhancedCTCLoss with use_center_loss=True expects "
                    "predicts=(features, logits)")
            # Pseudo-labels: the argmax class at every (batch, time) position.
            feats = features.reshape(-1, features.size(-1))
            labels = logits.argmax(dim=2).reshape(-1)
            center_loss = self.center_loss_func(feats, labels)["loss_center"]
            loss = loss + center_loss * self.center_loss_weight

        return {'enhanced_ctc_loss': loss}
class CenterLoss(nn.Module):
    """Center loss: squared distance between each feature and its class center.

    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep
    Face Recognition. ECCV 2016.

    Args:
        num_classes: number of classes (rows of ``centers``).
        feat_dim: feature dimension (columns of ``centers``).
        center_file_path: loading pre-computed centers is not implemented.
            A non-None value only triggers a warning, and the centers stay
            randomly initialised.
        init_center: accepted for compatibility with callers that pass it;
            currently unused.
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None,
                 init_center=False):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Learnable centers, one row per class.
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
        if center_file_path is not None:
            import warnings
            warnings.warn(
                "CenterLoss: center_file_path is not supported; centers are "
                "randomly initialised.",
                stacklevel=2)

    def forward(self, predicts, batch):
        """Compute the center loss.

        Args:
            predicts: feature matrix of shape ``(N, feat_dim)``.
            batch: integer class labels of shape ``(N,)``.

        Returns:
            dict with key ``'loss_center'``: the summed squared distance of
            each sample to its own class center, divided by ``N``.
        """
        x, labels = predicts, batch
        batch_size = x.size(0)

        # Squared Euclidean distances via ||x||^2 + ||c||^2 - 2 x.c,
        # giving an (N, num_classes) matrix.
        distmat = (
            x.pow(2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes)
            + self.centers.pow(2).sum(dim=1, keepdim=True)
            .expand(self.num_classes, batch_size).t()
        )
        # Keyword beta/alpha: the positional addmm_(beta, alpha, ...) overload
        # is deprecated.
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # Build the mask on the same device as the inputs (avoids a CPU/CUDA
        # mismatch in eq()).
        classes = torch.arange(self.num_classes, dtype=torch.long, device=x.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        # Clamping after masking adds 1e-12 per masked-out entry. This is
        # negligible and kept for parity with the reference implementation.
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return {'loss_center': loss}
|