import os
import pickle

import torch
from torch import nn


class CTCLoss(nn.Module):
    """CTC loss with optional focal weighting of the per-sample losses."""

    def __init__(self, loss_cfg, reduction='mean', use_focal_loss=False):
        super().__init__()
        # Focal weighting needs per-sample losses, so turn off the reduction.
        if use_focal_loss:
            reduction = 'none'
        self.loss_func = nn.CTCLoss(
            blank=loss_cfg['blank_idx'], reduction=reduction, zero_infinity=True)
        self.use_focal_loss = use_focal_loss

    def forward(self, pred, args):
        # pred is a (features, logits) pair; CTC only needs the logits [N, T, C].
        pred = pred[1]
        batch_size = pred.size(0)
        label, label_length = args['targets'], args['targets_lengths']
        # nn.CTCLoss expects log-probabilities of shape [T, N, C].
        pred = pred.log_softmax(2).permute(1, 0, 2)
        preds_lengths = torch.full((batch_size,), pred.size(0), dtype=torch.long)
        loss = self.loss_func(pred, label, preds_lengths, label_length)
        if self.use_focal_loss:
            # Focal weighting: scale each sample's loss by (1 - exp(-loss))^2,
            # which down-weights easy (low-loss) samples.
            weight = (1.0 - torch.exp(-loss)) ** 2
            loss = loss * weight
        loss = loss.mean()
        return {'loss': loss}
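
# A quick way to sanity-check CTCLoss, assuming the model emits a
# (features, logits) pair with logits of shape [N, T, C]. The shapes and the
# blank index 0 below are illustrative assumptions, not values taken from any
# particular config.
def _demo_ctc_loss():
    features = torch.randn(4, 32, 96)   # dummy frame-level features [N, T, F]
    logits = torch.randn(4, 32, 5990)   # dummy classifier output [N, T, C]
    args = {
        'targets': torch.randint(1, 5990, (4, 10), dtype=torch.long),
        'targets_lengths': torch.full((4,), 10, dtype=torch.long),
    }
    loss_fn = CTCLoss({'blank_idx': 0}, use_focal_loss=True)
    return loss_fn((features, logits), args)['loss']  # scalar tensor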

class EnhancedCTCLoss(nn.Module):
    """CTC loss optionally combined with a weighted center loss."""

    def __init__(self,
                 loss_cfg,
                 use_center_loss=True,
                 center_loss_weight=0.05,
                 num_classes=5990,
                 feat_dim=96,
                 init_center=False,
                 center_file_path=None,
                 **kwargs):
        super().__init__()
        self.ctc_loss_func = CTCLoss(loss_cfg)
        self.use_center_loss = use_center_loss
        if use_center_loss:
            self.center_loss_func = CenterLoss(
                num_classes=num_classes,
                feat_dim=feat_dim,
                center_file_path=center_file_path)
            self.center_loss_weight = center_loss_weight

    def forward(self, predicts, args):
        loss = self.ctc_loss_func(predicts, args)['loss']
        if self.use_center_loss:
            center_loss = (self.center_loss_func(predicts)['loss_center'] *
                           self.center_loss_weight)
            loss = loss + center_loss
        return {'loss': loss}
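
# The combined loss is driven the same way; num_classes and feat_dim must
# match the recognizer's character-set size and feature width (5990 and 96
# are just this file's defaults).
def _demo_enhanced_ctc_loss():
    features = torch.randn(4, 32, 96)
    logits = torch.randn(4, 32, 5990)
    args = {
        'targets': torch.randint(1, 5990, (4, 10), dtype=torch.long),
        'targets_lengths': torch.full((4,), 10, dtype=torch.long),
    }
    loss_fn = EnhancedCTCLoss({'blank_idx': 0}, num_classes=5990, feat_dim=96)
    return loss_fn((features, logits), args)['loss']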

class CenterLoss(nn.Module):
    """Center loss on the frame-level features.

    Reference: Wen et al. A Discriminative Feature Learning Approach for
    Deep Face Recognition. ECCV 2016.
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
        # Optionally initialize the centers from a pickled {class_idx: vector}
        # dict, mirroring the original Paddle implementation.
        if center_file_path is not None:
            assert os.path.exists(center_file_path), \
                f"center path({center_file_path}) must exist when it is not None."
            with open(center_file_path, 'rb') as f:
                char_dict = pickle.load(f)
            with torch.no_grad():
                for key in char_dict:
                    self.centers[key] = torch.as_tensor(char_dict[key])

    def forward(self, predicts):
        # predicts is a (features, logits) pair; the argmax of the logits
        # serves as the pseudo-label of each time step.
        features, predicts = predicts
        label = predicts.argmax(dim=2).reshape(-1)
        feats_reshape = features.reshape(-1, features.size(-1))
        batch_size = feats_reshape.size(0)

        # Squared L2 distance between every feature and every center, using
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x . c
        square_feat = feats_reshape.pow(2).sum(dim=1, keepdim=True).expand(
            batch_size, self.num_classes)
        square_center = self.centers.pow(2).sum(dim=1, keepdim=True).expand(
            self.num_classes, batch_size).t()
        distmat = square_feat + square_center
        distmat.addmm_(feats_reshape, self.centers.t(), beta=1, alpha=-2)

        # Mask selecting, for each feature, the distance to its own center.
        classes = torch.arange(
            self.num_classes, dtype=torch.long, device=feats_reshape.device)
        labels = label.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return {'loss_center': loss}
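
# Because the class centers are an nn.Parameter, they only move if the
# optimizer sees them. A minimal training-step sketch; `model`, `images`, and
# `args` are placeholders for any recognizer returning (features, logits) and
# its batch.
def _demo_train_step(model, images, args, loss_fn, optimizer):
    out = loss_fn(model(images), args)
    optimizer.zero_grad()
    out['loss'].backward()
    optimizer.step()
    return out['loss'].item()

# The centers are registered by passing loss_fn.parameters() as well, e.g.:
#   loss_fn = EnhancedCTCLoss({'blank_idx': 0})
#   optimizer = torch.optim.Adam(
#       list(model.parameters()) + list(loss_fn.parameters()), lr=1e-3)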