123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
import numpy as np
import torch
from torch import nn
class PSELoss(nn.Module):
    """PSE loss: a weighted sum of a text dice loss and a kernel dice loss.

    The last channel of ``outputs``/``labels`` is treated as the text map and
    all preceding channels as kernel maps. The text loss is computed on pixels
    picked by online hard example mining (OHEM); the kernel losses are
    computed on pixels that the network currently predicts as text.

    Args:
        Lambda: Weight of the text loss; the kernel loss is weighted by
            ``1 - Lambda``.
        ratio: Maximum number of hard negative pixels kept per positive pixel
            during OHEM.
        reduction: ``'mean'`` or ``'sum'``; how per-sample losses are reduced
            over the batch.
    """

    def __init__(self, Lambda, ratio=3, reduction='mean'):
        super().__init__()
        assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
        self.Lambda = Lambda
        self.ratio = ratio
        self.reduction = reduction

    def forward(self, outputs, labels, training_masks):
        """Compute the text loss, the kernel loss and their weighted sum.

        Args:
            outputs: 4-D tensor of raw logits indexed as
                ``[batch, channel, :, :]``; the last channel is the text map.
            labels: 4-D ground-truth tensor with the same channel layout as
                ``outputs``.
            training_masks: 3-D tensor indexed as ``[batch, :, :]``; pixels
                with a value ``<= 0.5`` are ignored.

        Returns:
            A tuple ``(loss_text, loss_kernels, loss)`` of reduced losses,
            where ``loss = Lambda * loss_text + (1 - Lambda) * loss_kernels``.
        """
        texts = outputs[:, -1, :, :]
        kernels = outputs[:, :-1, :, :]
        gt_texts = labels[:, -1, :, :]
        gt_kernels = labels[:, :-1, :, :]

        # Text loss: restricted to positives plus the hardest negatives.
        text_masks = self.ohem_batch(texts, gt_texts, training_masks)
        text_masks = text_masks.to(outputs.device)
        loss_text = self.dice_loss(texts, gt_texts, text_masks)

        # Kernel losses: restricted to pixels currently predicted as text and
        # not ignored. The mask is a constant, so no gradient is tracked and
        # it is built directly on the output device.
        with torch.no_grad():
            predicted_text = torch.sigmoid(texts) > 0.5
            not_ignored = training_masks.to(outputs.device) > 0.5
            kernel_masks = (predicted_text & not_ignored).float()

        loss_kernels = torch.stack([
            self.dice_loss(kernels[:, i, :, :], gt_kernels[:, i, :, :], kernel_masks)
            for i in range(gt_kernels.size(1))
        ]).mean(0)

        if self.reduction == 'mean':
            loss_text = loss_text.mean()
            loss_kernels = loss_kernels.mean()
        elif self.reduction == 'sum':
            loss_text = loss_text.sum()
            loss_kernels = loss_kernels.sum()

        loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
        return loss_text, loss_kernels, loss

    def dice_loss(self, input, target, mask):
        """Return the per-sample masked dice loss.

        Args:
            input: Raw logits; a sigmoid is applied before the loss.
            target: Ground-truth map with the same leading (batch) dimension.
            mask: Multiplicative mask applied to both ``input`` and ``target``.

        Returns:
            A 1-D tensor with one loss value per sample along dimension 0.
        """
        input = torch.sigmoid(input)

        # Flatten everything except the batch dimension.
        input = input.contiguous().view(input.size()[0], -1)
        target = target.contiguous().view(target.size()[0], -1)
        mask = mask.contiguous().view(mask.size()[0], -1)

        input = input * mask
        target = target * mask

        # The 0.001 terms keep the denominator non-zero for empty maps.
        a = torch.sum(input * target, 1)
        b = torch.sum(input * input, 1) + 0.001
        c = torch.sum(target * target, 1) + 0.001
        d = (2 * a) / (b + c)
        return 1 - d

    @staticmethod
    def _as_batch_mask(mask):
        """Return a 2-D array as a float32 array with a leading axis of 1."""
        return mask.reshape(1, mask.shape[0], mask.shape[1]).astype('float32')

    def ohem_single(self, score, gt_text, training_mask):
        """Select the pixels of one sample that contribute to the text loss.

        Args:
            score: 2-D NumPy array of text scores.
            gt_text: 2-D NumPy array; values ``> 0.5`` are positive pixels.
            training_mask: 2-D NumPy array; values ``<= 0.5`` are ignored.

        Returns:
            A float32 NumPy array of shape ``(1, H, W)``. If there are no
            usable positives or no negatives, this is ``training_mask``
            itself; otherwise it marks the positives plus the highest-scoring
            negatives (at most ``self.ratio`` per positive, ties included),
            excluding ignored pixels.
        """
        # Positives that are not masked out by the training mask.
        pos_num = int(np.sum(gt_text > 0.5)) - int(
            np.sum((gt_text > 0.5) & (training_mask <= 0.5))
        )
        if pos_num == 0:
            return self._as_batch_mask(training_mask)

        neg_num = int(np.sum(gt_text <= 0.5))
        # Honour the configured negative:positive ratio rather than a
        # hard-coded constant.
        neg_num = int(min(pos_num * self.ratio, neg_num))
        if neg_num == 0:
            return self._as_batch_mask(training_mask)

        neg_score = score[gt_text <= 0.5]
        # Sort the negative scores from high to low.
        neg_score_sorted = np.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]

        # Keep the high-scoring negatives and all positives, minus ignored
        # pixels.
        selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
        return self._as_batch_mask(selected_mask)

    def ohem_batch(self, scores, gt_texts, training_masks):
        """Run :meth:`ohem_single` on every sample of a batch.

        Args:
            scores: Tensor of text scores indexed as ``[batch, :, :]``.
            gt_texts: Ground-truth text tensor with the same layout.
            training_masks: Training-mask tensor with the same layout.

        Returns:
            A float CPU tensor holding the per-sample masks concatenated
            along dimension 0.
        """
        scores = scores.detach().cpu().numpy()
        gt_texts = gt_texts.detach().cpu().numpy()
        training_masks = training_masks.detach().cpu().numpy()

        selected_masks = [
            self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])
            for i in range(scores.shape[0])
        ]
        selected_masks = np.concatenate(selected_masks, 0)
        return torch.from_numpy(selected_masks).float()
|