loss.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import tensorflow as tf
  2. import keras.backend as K
  3. def ctc_lambda_func(args):
  4. """
  5. 定义ctc损失函数
  6. 参数:y_pred:预测值,labels:标签,input_length:lstm tiemstep,label_length:标签长度
  7. """
  8. y_pred, labels, input_length, label_length = args
  9. # return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
  10. return my_ctc_batch_cost(labels, y_pred, input_length, label_length, mode=1)
  11. def my_ctc_batch_cost(y_true, y_pred, input_length, label_length, mode=0):
  12. """Runs CTC loss algorithm on each batch element.
  13. Args:
  14. y_true: tensor `(samples, max_string_length)`
  15. containing the truth labels.
  16. y_pred: tensor `(samples, time_steps, num_categories)`
  17. containing the prediction, or output of the softmax.
  18. input_length: tensor `(samples, 1)` containing the sequence length for
  19. each batch item in `y_pred`.
  20. label_length: tensor `(samples, 1)` containing the sequence length for
  21. each batch item in `y_true`.
  22. Returns:
  23. Tensor with shape (samples,1) containing the
  24. CTC loss of each element.
  25. """
  26. label_length = tf.cast(
  27. tf.squeeze(label_length, axis=-1), tf.int32)
  28. input_length = tf.cast(
  29. tf.squeeze(input_length, axis=-1), tf.int32)
  30. sparse_labels = tf.cast(
  31. K.ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
  32. y_pred = tf.math.log(tf.compat.v1.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
  33. loss = tf.compat.v1.nn.ctc_loss(inputs=y_pred,
  34. labels=sparse_labels,
  35. sequence_length=input_length,
  36. preprocess_collapse_repeated=False,
  37. ctc_merge_repeated=False)
  38. loss = tf.expand_dims(loss, 1)
  39. if mode == 1:
  40. loss = focal_ctc(sparse_labels, y_pred, input_length, loss)
  41. return loss
  42. def focal_ctc(targets, logits, seq_len, ctc_loss, alpha=0.5, gamma=2.0):
  43. # FOCAL LOSS
  44. # This function computes Focal Loss
  45. # Inputs: alpha, gamma, targets, logits, seq_len
  46. # Default Values: alpha=0.5 and gamma=2.0
  47. # Output: loss
  48. # ctc_loss = tf.compat.v1.nn.ctc_loss(labels=targets, inputs=logits, sequence_length=seq_len, time_major=True)
  49. p = tf.exp(-ctc_loss)
  50. # ((alpha)*((1-p)**gamma)*(ctc_loss))
  51. focal_ctc_loss = tf.multiply(tf.multiply(alpha, tf.pow((1-p), gamma)), ctc_loss)
  52. loss = tf.reduce_mean(focal_ctc_loss)
  53. return loss
  54. def perceptual_loss(gamma=2., alpha=.25):
  55. from model import Vgg19
  56. def perceptual_loss_fixed(y_true, y_pred):
  57. if globals().get("vgg") is None:
  58. vgg = Vgg19("./vgg19.npy")
  59. globals().update({"vgg": vgg})
  60. print("init vgg19 success!")
  61. else:
  62. vgg = globals().get("vgg")
  63. # mask_1 = tf.where(y_true[:, :, :, 0] >= 0.75, 1, 0)
  64. # mask_2 = tf.where(y_true[:, :, :, 1] >= 0.75, 1, 0)
  65. # mask_3 = tf.where(y_true[:, :, :, 2] >= 0.75, 1, 0)
  66. # mask_white = tf.expand_dims(mask_1 * mask_2 * mask_3, -1)
  67. # mask_white = tf.concat([mask_white, mask_white, mask_white], -1)
  68. # y_true_mask = tf.where(mask_white == 1, 1., y_true)
  69. # y_pred_mask = tf.where(mask_white == 1, 1., y_pred)
  70. # print("y_pred.shape", y_pred.shape)
  71. y_pred = tf.concat([y_pred, y_pred, y_pred], -1)
  72. y_true = tf.concat([y_true, y_true, y_true], -1)
  73. vgg.build(y_true)
  74. vgg_true_1 = vgg.conv1_1
  75. vgg_true_2 = vgg.conv2_1
  76. vgg_true_3 = vgg.conv3_1
  77. vgg_true_4 = vgg.conv4_1
  78. vgg_true_5 = vgg.conv5_1
  79. vgg.build(y_pred)
  80. vgg_pred_1 = vgg.conv1_1
  81. vgg_pred_2 = vgg.conv2_1
  82. vgg_pred_3 = vgg.conv3_1
  83. vgg_pred_4 = vgg.conv4_1
  84. vgg_pred_5 = vgg.conv5_1
  85. loss_0 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(y_true, y_pred)
  86. loss_1 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_1, vgg_pred_1)
  87. loss_2 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_2, vgg_pred_2)
  88. loss_3 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_3, vgg_pred_3)
  89. loss_4 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_4, vgg_pred_4)
  90. loss_5 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_5, vgg_pred_5)
  91. return (loss_0+loss_1+loss_2+loss_3+loss_4+loss_5) / 6
  92. return perceptual_loss_fixed