loss.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import tensorflow as tf
  2. import keras.backend as K
  def ctc_lambda_func(args):
      """Keras ``Lambda`` entry point that computes the focal-weighted CTC loss.

      Args:
          args: sequence ``[y_pred, labels, input_length, label_length]``:
              the network prediction, the label tensor, the number of LSTM
              time steps per sample, and the label length per sample.

      Returns:
          Whatever ``my_ctc_batch_cost(..., mode=1)`` returns.
      """
      predictions, targets, pred_lengths, target_lengths = args
      # Plain (non-focal) alternative:
      # K.ctc_batch_cost(labels, y_pred, input_length, label_length)
      return my_ctc_batch_cost(
          y_true=targets,
          y_pred=predictions,
          input_length=pred_lengths,
          label_length=target_lengths,
          mode=1,
      )
  def my_ctc_batch_cost(y_true, y_pred, input_length, label_length, mode=0):
      """Runs the CTC loss algorithm on each batch element.

      Args:
          y_true: tensor ``(samples, max_string_length)`` containing the
              truth labels.
          y_pred: tensor ``(samples, time_steps, num_categories)`` containing
              the prediction, i.e. the output of the softmax.
          input_length: tensor ``(samples, 1)`` containing the sequence
              length of each batch item in ``y_pred``.
          label_length: tensor ``(samples, 1)`` containing the sequence
              length of each batch item in ``y_true``.
          mode: ``1`` passes the per-sample CTC losses through ``focal_ctc``;
              any other value returns the plain CTC losses.

      Returns:
          If ``mode != 1``: tensor of shape ``(samples, 1)`` with the CTC loss
          of each element. If ``mode == 1``: the value returned by
          ``focal_ctc`` -- NOTE(review): ``focal_ctc`` reduces with
          ``tf.reduce_mean``, so this is a single scalar, not ``(samples, 1)``.
      """
      label_length = tf.cast(
          tf.squeeze(label_length, axis=-1), tf.int32)
      input_length = tf.cast(
          tf.squeeze(input_length, axis=-1), tf.int32)
      sparse_labels = tf.cast(
          K.ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
      # ctc_loss expects time-major log-probabilities; epsilon avoids log(0).
      y_pred = tf.math.log(tf.compat.v1.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
      # ctc_merge_repeated=False: repeated non-blank labels are not merged
      # during the CTC computation (TF documents this as a simplified,
      # non-standard CTC variant).
      loss = tf.compat.v1.nn.ctc_loss(inputs=y_pred,
                                      labels=sparse_labels,
                                      sequence_length=input_length,
                                      preprocess_collapse_repeated=False,
                                      ctc_merge_repeated=False)
      # (samples,) -> (samples, 1)
      loss = tf.expand_dims(loss, 1)
      if mode == 1:
          loss = focal_ctc(sparse_labels, y_pred, input_length, loss)
      return loss
  def focal_ctc(targets, logits, seq_len, ctc_loss, alpha=0.5, gamma=2.0, reduce_batch=True):
      """Apply focal weighting to precomputed CTC losses.

      Computes ``alpha * (1 - p) ** gamma * ctc_loss`` with
      ``p = exp(-ctc_loss)``. ``p`` is the sequence probability when
      ``ctc_loss`` is a CTC negative log-likelihood. Since ``(1 - p) ** gamma``
      shrinks as the loss shrinks, easy samples contribute less.

      Args:
          targets: unused; kept so existing callers keep working.
          logits: unused; kept so existing callers keep working.
          seq_len: unused; kept so existing callers keep working.
          ctc_loss: tensor of CTC losses, e.g. the output of
              ``tf.compat.v1.nn.ctc_loss``.
          alpha: constant scale factor (default 0.5).
          gamma: focusing exponent (default 2.0).
          reduce_batch: if True (default, the original behavior) return the
              mean over all elements, a scalar. If False return the weighted
              losses with the same shape as ``ctc_loss``. This lets per-sample
              weighting or reduction happen downstream.

      Returns:
          Scalar mean focal CTC loss, or the element-wise focal losses when
          ``reduce_batch`` is False.
      """
      p = tf.exp(-ctc_loss)
      # alpha * (1 - p)^gamma * ctc_loss
      focal_ctc_loss = tf.multiply(tf.multiply(alpha, tf.pow((1 - p), gamma)), ctc_loss)
      if not reduce_batch:
          return focal_ctc_loss
      return tf.reduce_mean(focal_ctc_loss)
  def ctc_decode_mse_loss2(args, *, num_classes=35 + 2, time_step=21, label_len=8):
      """Per-sample L1 distance between the prediction and a target built from its argmax.

      For every sample, an integer target per time step is built as follows
      (class 0 plays the role of "empty" here -- presumably blank/padding,
      confirm against the label encoding):

      * For each of the first ``label_len`` label positions ``j``, every time
        step ``k`` whose argmax prediction equals ``labels[j]`` and is not 0
        gets target ``labels[j]``.
      * If no time step matched and ``labels[j]`` is not 0, ``labels[j]`` is
        written into the first time step whose argmax is 0 and whose target
        is still 0.
      * Every other time step keeps target 0.

      The target is one-hot encoded and compared to ``y_pred`` with the mean
      absolute difference over classes and time steps. Despite "mse" in the
      name, the distance is L1.

      Args:
          args: ``[y_pred, labels, input_length, label_length]``.
              ``input_length`` and ``label_length`` are unused.
              ``y_pred`` is presumably the ``(batch, time_step, num_classes)``
              softmax output, and ``labels`` a ``(batch, >= label_len)`` integer
              label tensor -- TODO confirm against the model.
          num_classes: one-hot depth; must equal the last dimension of ``y_pred``.
          time_step: number of time steps scanned; must equal ``y_pred``'s
              time dimension.
          label_len: number of label positions scanned per sample.

      Returns:
          Float32 tensor of shape ``(batch, 1)`` with the loss of each sample.
      """
      y_pred, labels, input_length, label_length = args
      # Greedy (argmax) class per time step: (batch, time).
      y_max = tf.argmax(y_pred, axis=-1, name='raw_prediction', output_type=tf.int32)
      labels = tf.cast(labels, tf.int32)
      # Size the target buffer by the runtime batch. A fixed 32 rows broke any
      # batch of <= 20 samples, e.g. the last partial batch of an epoch.
      batch = tf.shape(labels)[0]
      new_label = tf.zeros(tf.stack([batch, time_step]), dtype=tf.int32)

      def _set_step(sample, k, value):
          # Return the (1, time_step) `sample` with column k replaced by the
          # (1, 1) `value`; empty slices make this valid for k == 0 and the last k.
          return tf.concat([sample[:, :k], value, sample[:, k + 1:]], axis=-1)

      @tf.function
      def body(_i, _label):
          new_sample = tf.zeros((1, time_step), dtype=tf.int32)
          for j in range(0, label_len):
              label_j = labels[_i:_i + 1, j:j + 1]
              # Pass 1: mark every step whose prediction already equals label j.
              find_flag = False
              for k in range(0, time_step):
                  if tf.equal(y_max[_i, k], labels[_i, j]) and tf.not_equal(y_max[_i, k], 0):
                      find_flag = True
                      new_sample = _set_step(new_sample, k, label_j)
              # Pass 2: a missed non-zero label takes the first step that is
              # predicted 0 and not yet assigned.
              if not find_flag and tf.not_equal(labels[_i, j], 0):
                  find_flag2 = False
                  for k in range(0, time_step):
                      if not find_flag2 and tf.equal(y_max[_i, k], 0) and tf.equal(new_sample[0, k], 0):
                          find_flag2 = True
                          new_sample = _set_step(new_sample, k, label_j)
          # Replace row _i. Empty slices make this valid for the first and
          # last rows, so no special cases (and no batch/time mix-up) are needed.
          _label = tf.concat([_label[:_i, :], new_sample, _label[_i + 1:, :]], axis=0)
          return tf.add(_i, 1), _label

      def cond(_i, _label):
          return tf.less(_i, batch)

      # Start at row 0; starting at 1 left sample 0 with an all-zero target.
      i = tf.constant(0, dtype=tf.int32)
      _, new_label = tf.while_loop(cond, body, [i, new_label],
                                   shape_invariants=[i.get_shape(), tf.TensorShape([None, None])])
      # (batch, time) -> (batch, time, num_classes)
      new_label = tf.one_hot(new_label, depth=num_classes, dtype=tf.float32)
      # Mean absolute error over classes and time steps -> (batch,)
      loss = tf.reduce_mean(tf.abs(new_label - y_pred), axis=[1, 2])
      return tf.expand_dims(loss, -1)