import tensorflow as tf
import keras.backend as K


def ctc_lambda_func(args):
    """CTC loss wrapper for use inside a Keras Lambda layer.

    Args:
        y_pred: softmax output of the recognition network.
        labels: ground-truth label sequences.
        input_length: number of LSTM time steps for each sample.
        label_length: length of each ground-truth label sequence.
    """
    y_pred, labels, input_length, label_length = args
    # return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
    return my_ctc_batch_cost(labels, y_pred, input_length, label_length, mode=1)
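
# --- Hedged usage sketch (not part of the original file) ---------------------
# One common way to wire ctc_lambda_func into training is through a Keras
# Lambda layer. `base_model` and `max_label_len` below are assumptions about
# the surrounding project, not names taken from the original code.
def build_ctc_training_model(base_model, max_label_len):
    from keras.layers import Input, Lambda
    from keras.models import Model

    labels = Input(name="labels", shape=(max_label_len,), dtype="float32")
    input_length = Input(name="input_length", shape=(1,), dtype="int64")
    label_length = Input(name="label_length", shape=(1,), dtype="int64")
    y_pred = base_model.output  # (batch, time_steps, num_categories) softmax

    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name="ctc")(
        [y_pred, labels, input_length, label_length])
    model = Model(inputs=[base_model.input, labels, input_length, label_length],
                  outputs=loss_out)
    # The CTC loss is already computed inside the Lambda layer, so compile with
    # a pass-through loss that simply returns y_pred.
    model.compile(optimizer="adam", loss=lambda y_true, y_pred: y_pred)
    return model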

def my_ctc_batch_cost(y_true, y_pred, input_length, label_length, mode=0):
    """Runs the CTC loss algorithm on each batch element.

    Args:
        y_true: tensor `(samples, max_string_length)` containing the truth labels.
        y_pred: tensor `(samples, time_steps, num_categories)` containing the
            prediction, or output of the softmax.
        input_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_pred`.
        label_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_true`.
        mode: 0 returns the plain per-sample CTC loss; 1 applies the focal
            re-weighting from `focal_ctc`.

    Returns:
        Tensor with shape `(samples, 1)` containing the CTC loss of each
        element, or a scalar focal-CTC loss when `mode == 1`.
    """
    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
    sparse_labels = tf.cast(
        K.ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
    # ctc_loss expects time-major log probabilities: (time_steps, samples, num_categories).
    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
    loss = tf.compat.v1.nn.ctc_loss(inputs=y_pred,
                                    labels=sparse_labels,
                                    sequence_length=input_length,
                                    preprocess_collapse_repeated=False,
                                    ctc_merge_repeated=False)
    loss = tf.expand_dims(loss, 1)
    if mode == 1:
        loss = focal_ctc(sparse_labels, y_pred, input_length, loss)
    return loss

def focal_ctc(targets, logits, seq_len, ctc_loss, alpha=0.5, gamma=2.0):
    """Applies focal-loss re-weighting to a precomputed per-sample CTC loss.

    `targets`, `logits` and `seq_len` are kept from the original signature and
    are only needed if the CTC loss is recomputed here (see the commented line
    below). Returns the scalar mean of alpha * (1 - p)**gamma * ctc_loss,
    where p = exp(-ctc_loss).
    """
    # ctc_loss = tf.compat.v1.nn.ctc_loss(labels=targets, inputs=logits,
    #                                     sequence_length=seq_len, time_major=True)
    p = tf.exp(-ctc_loss)
    # alpha * (1 - p)**gamma * ctc_loss
    focal_ctc_loss = tf.multiply(tf.multiply(alpha, tf.pow((1 - p), gamma)), ctc_loss)
    loss = tf.reduce_mean(focal_ctc_loss)
    return loss
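
# --- Hedged sanity-check sketch (not part of the original file) --------------
# Quick illustration of the focal re-weighting on a dummy per-sample CTC loss
# vector: easy samples (small loss, p close to 1) are down-weighted far more
# strongly than hard ones.
def _focal_ctc_demo():
    # Per-sample CTC losses: an easy sample (0.1), a medium one (1.0), a hard one (5.0).
    dummy_ctc = tf.constant([[0.1], [1.0], [5.0]])
    # targets/logits/seq_len are unused by the formula above (the internal CTC
    # recomputation is commented out), so placeholders satisfy the signature.
    return focal_ctc(None, None, None, dummy_ctc, alpha=0.5, gamma=2.0)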

def perceptual_loss(gamma=2., alpha=.25):
    """Returns a VGG19-based perceptual loss for single-channel images.

    Note: `gamma` and `alpha` are kept for signature compatibility and are not
    used in this implementation.
    """
    from model import Vgg19

    def perceptual_loss_fixed(y_true, y_pred):
        # Build the VGG19 feature extractor once and cache it as a global.
        if globals().get("vgg") is None:
            vgg = Vgg19("./vgg19.npy")
            globals().update({"vgg": vgg})
            print("init vgg19 success!")
        else:
            vgg = globals().get("vgg")
        # Optional masking of near-white pixels, kept here for reference:
        # mask_1 = tf.where(y_true[:, :, :, 0] >= 0.75, 1, 0)
        # mask_2 = tf.where(y_true[:, :, :, 1] >= 0.75, 1, 0)
        # mask_3 = tf.where(y_true[:, :, :, 2] >= 0.75, 1, 0)
        # mask_white = tf.expand_dims(mask_1 * mask_2 * mask_3, -1)
        # mask_white = tf.concat([mask_white, mask_white, mask_white], -1)
        # y_true_mask = tf.where(mask_white == 1, 1., y_true)
        # y_pred_mask = tf.where(mask_white == 1, 1., y_pred)
        # Replicate the single channel to 3 channels so VGG19 can consume it.
        y_pred = tf.concat([y_pred, y_pred, y_pred], -1)
        y_true = tf.concat([y_true, y_true, y_true], -1)
        # Extract features at several depths for the ground truth ...
        vgg.build(y_true)
        vgg_true_1 = vgg.conv1_1
        vgg_true_2 = vgg.conv2_1
        vgg_true_3 = vgg.conv3_1
        vgg_true_4 = vgg.conv4_1
        vgg_true_5 = vgg.conv5_1
        # ... and for the prediction.
        vgg.build(y_pred)
        vgg_pred_1 = vgg.conv1_1
        vgg_pred_2 = vgg.conv2_1
        vgg_pred_3 = vgg.conv3_1
        vgg_pred_4 = vgg.conv4_1
        vgg_pred_5 = vgg.conv5_1
        # l2_focal_loss (defined elsewhere in this project) compares the images
        # and each pair of feature maps; the six terms are averaged.
        loss_0 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(y_true, y_pred)
        loss_1 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_1, vgg_pred_1)
        loss_2 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_2, vgg_pred_2)
        loss_3 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_3, vgg_pred_3)
        loss_4 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_4, vgg_pred_4)
        loss_5 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_5, vgg_pred_5)
        return (loss_0 + loss_1 + loss_2 + loss_3 + loss_4 + loss_5) / 6

    return perceptual_loss_fixed
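
# --- Hedged usage sketch (not part of the original file) ---------------------
# perceptual_loss() returns a Keras-compatible loss function, so a typical use
# is to pass it to model.compile; `generator_model` is an assumed name for a
# single-channel image-to-image network from the surrounding project.
def compile_with_perceptual_loss(generator_model):
    generator_model.compile(optimizer="adam", loss=perceptual_loss())
    return generator_model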