import torch
from torch import nn


class CTCLoss(nn.Module):
    """CTC loss wrapper with an optional focal re-weighting.

    Args:
        loss_cfg: Mapping that must contain ``'blank_idx'``, the index of the
            CTC blank symbol.
        reduction: Reduction passed to ``torch.nn.CTCLoss``. It is overridden
            with ``'none'`` when ``use_focal_loss`` is true, because the focal
            weight is applied per sample before averaging.
        use_focal_loss: When true, each per-sample loss ``l`` is scaled by
            ``(1 - exp(-l)) ** 2`` and the mean of the result is returned.
    """

    def __init__(self, loss_cfg, reduction='mean', use_focal_loss=False):
        super().__init__()
        if use_focal_loss:
            # Focal weighting needs the unreduced, per-sample losses.
            reduction = 'none'
        self.loss_func = torch.nn.CTCLoss(
            blank=loss_cfg['blank_idx'],
            reduction=reduction,
            zero_infinity=True,
        )
        self.use_focal_loss = use_focal_loss

    def forward(self, pred, args):
        """Compute the CTC loss.

        Args:
            pred: Indexable whose element ``1`` holds the class scores. The
                scores are soft-maxed over dim 2 and permuted with
                ``(1, 0, 2)``, i.e. they are treated as
                ``(batch, time, classes)``.
            args: Mapping with ``'targets'`` and ``'targets_lengths'``, which
                are forwarded to ``torch.nn.CTCLoss`` unchanged.

        Returns:
            ``{'loss': tensor}``.
        """
        logits = pred[1]
        batch_size = logits.size(0)
        targets, target_lengths = args['targets'], args['targets_lengths']

        # torch.nn.CTCLoss expects log-probabilities laid out as (T, N, C).
        log_probs = logits.log_softmax(2).permute(1, 0, 2)
        # Every sample is assumed to use the full time axis.
        input_lengths = torch.full(
            (batch_size,), log_probs.size(0), dtype=torch.long)

        loss = self.loss_func(log_probs, targets, input_lengths, target_lengths)
        if self.use_focal_loss:
            focal_weight = (1.0 - torch.exp(-loss)) ** 2
            loss = (loss * focal_weight).mean()
        return {'loss': loss}


class EnhancedCTCLoss(nn.Module):
    """CTC loss optionally combined with a weighted center loss.

    Args:
        loss_cfg: Configuration forwarded to :class:`CTCLoss`.
        use_center_loss: Whether to add the center-loss term.
        center_loss_weight: Multiplier applied to the center-loss term.
        num_classes: Number of classes passed to ``CenterLoss``.
        feat_dim: Feature dimension passed to ``CenterLoss``.
        init_center: Accepted for interface compatibility; not used here.
        center_file_path: Forwarded to ``CenterLoss``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 loss_cfg,
                 use_center_loss=True,
                 center_loss_weight=0.05,
                 num_classes=5990,
                 feat_dim=96,
                 init_center=False,
                 center_file_path=None,
                 **kwargs):
        super().__init__()
        self.ctc_loss_func = CTCLoss(loss_cfg)

        self.use_center_loss = bool(use_center_loss)
        # Always defined so the attribute can be read even when the center
        # loss is disabled.
        self.center_loss_weight = center_loss_weight
        if self.use_center_loss:
            self.center_loss_func = CenterLoss(
                num_classes=num_classes,
                feat_dim=feat_dim,
                center_file_path=center_file_path)

    def forward(self, predicts, args):
        """Return ``{'loss': ctc_loss [+ center_loss_weight * center_loss]}``.

        Args:
            predicts: Model output handed to both the CTC loss and, when
                enabled, the center loss.
            args: Target mapping handed to the CTC loss.
        """
        loss = self.ctc_loss_func(predicts, args)['loss']
        if self.use_center_loss:
            center_term = self.center_loss_func(predicts)['loss_center']
            loss = loss + center_term * self.center_loss_weight
        return {'loss': loss}
class CenterLoss(nn.Module):
    """Center loss computed against pseudo-labels taken from the predictions.

    Reference:
        Wen et al. A Discriminative Feature Learning Approach for Deep Face
        Recognition. ECCV 2016.

    Args:
        num_classes: Number of class centers.
        feat_dim: Dimensionality of each center (and of the input features).
        center_file_path: Accepted for interface compatibility; loading
            pre-computed centers is not implemented, so it is ignored.
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = nn.Parameter(
            torch.randn(self.num_classes, self.feat_dim))
        # Retained as a public attribute for backward compatibility only;
        # forward() derives the device from its inputs instead.
        self.use_gpu = torch.cuda.is_available()

    def forward(self, predicts):
        """Compute the center loss.

        Args:
            predicts: Pair ``(features, logits)``. ``logits`` is 3-D and its
                argmax over dim 2 is used as the label of each position.
                ``features`` is flattened to ``(-1, features.size(-1))`` and
                must yield one row per label, with the last dimension equal
                to ``feat_dim``.

        Returns:
            ``{'loss_center': tensor}`` - the summed squared distance of each
            feature row to the center of its label, divided by the number of
            rows.
        """
        features, logits = predicts

        labels = logits.argmax(dim=2).reshape(-1)
        feats = features.reshape(-1, features.size(-1))
        num_rows = feats.size(0)

        # Squared Euclidean distance via ||f||^2 + ||c||^2 - 2 * f @ c^T.
        sq_feats = feats.pow(2).sum(dim=1, keepdim=True)        # (rows, 1)
        sq_centers = self.centers.pow(2).sum(dim=1).unsqueeze(0)  # (1, classes)
        distmat = torch.addmm(
            sq_feats + sq_centers, feats, self.centers.t(), beta=1, alpha=-2)

        # Build the one-hot mask on the same device as the labels, so the
        # module works wherever its inputs live (not wherever CUDA happens to
        # be available).
        classes = torch.arange(
            self.num_classes, dtype=torch.long, device=labels.device)
        mask = labels.unsqueeze(1).eq(classes.unsqueeze(0))

        dist = distmat * mask.float()
        # The clamp is applied after masking, so every masked-out entry
        # contributes the 1e-12 floor to the sum.
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / num_rows
        return {'loss_center': loss}