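CTC loss with an optional focal reweighting, plus an auxiliary alignment-based L1 loss that matches the softmax output against a per-timestep target built from the greedy decode (batch 32, 21 timesteps, 37 classes):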
```python
import tensorflow as tf
import keras.backend as K


def ctc_lambda_func(args):
    """CTC loss function for a Keras Lambda layer.

    Args:
        y_pred: predictions (softmax output).
        labels: ground-truth labels.
        input_length: LSTM timesteps per sample.
        label_length: length of each label sequence.
    """
    y_pred, labels, input_length, label_length = args
    # return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
    return my_ctc_batch_cost(labels, y_pred, input_length, label_length, mode=1)

def my_ctc_batch_cost(y_true, y_pred, input_length, label_length, mode=0):
    """Runs the CTC loss algorithm on each batch element.

    Args:
        y_true: tensor `(samples, max_string_length)` containing the truth labels.
        y_pred: tensor `(samples, time_steps, num_categories)` containing the
            predictions, i.e. the output of the softmax.
        input_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_pred`.
        label_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_true`.

    Returns:
        Tensor of shape `(samples, 1)` containing the CTC loss of each element;
        with `mode == 1` the focal variant returns a scalar mean instead.
    """
    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
    sparse_labels = tf.cast(
        K.ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
    # ctc_loss expects time-major log-probabilities: (time_steps, samples, num_categories).
    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
    loss = tf.compat.v1.nn.ctc_loss(inputs=y_pred,
                                    labels=sparse_labels,
                                    sequence_length=input_length,
                                    preprocess_collapse_repeated=False,
                                    ctc_merge_repeated=False)
    loss = tf.expand_dims(loss, 1)
    if mode == 1:
        loss = focal_ctc(sparse_labels, y_pred, input_length, loss)
    return loss
```
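When `mode == 1`, the per-sample CTC loss is reweighted with a focal factor before averaging: with p = exp(-ctc_loss), the weighted loss is alpha * (1 - p)**gamma * ctc_loss (defaults alpha=0.5, gamma=2.0), which down-weights samples the model already recognizes easily: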
```python
def focal_ctc(targets, logits, seq_len, ctc_loss, alpha=0.5, gamma=2.0):
    """Applies a focal weighting to a precomputed CTC loss.

    Args:
        targets, logits, seq_len: kept for the alternative of recomputing the
            CTC loss here (see the commented-out call below); otherwise unused.
        ctc_loss: per-sample CTC loss, shape `(samples, 1)`.
        alpha, gamma: focal-loss parameters (defaults 0.5 and 2.0).

    Returns:
        Scalar mean of the focally weighted loss.
    """
    # ctc_loss = tf.compat.v1.nn.ctc_loss(labels=targets, inputs=logits,
    #                                     sequence_length=seq_len, time_major=True)
    p = tf.exp(-ctc_loss)
    # alpha * (1 - p)**gamma * ctc_loss
    focal_ctc_loss = tf.multiply(tf.multiply(alpha, tf.pow((1 - p), gamma)), ctc_loss)
    loss = tf.reduce_mean(focal_ctc_loss)
    return loss
```
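A minimal sketch of wiring `ctc_lambda_func` into a Keras model through a `Lambda` layer, in the style of the standard Keras CTC OCR example; the layer names and shapes here are assumptions, not taken from the original:

```python
from keras.layers import Input, Lambda
from keras.models import Model

# Hypothetical shapes matching the constants used below: 21 timesteps,
# 37 classes, labels up to 8 characters. In a real model, y_pred would be
# the recognizer's softmax output rather than a bare Input.
y_pred = Input(shape=(21, 37), name='y_pred')
labels = Input(shape=(8,), name='labels')
input_length = Input(shape=(1,), name='input_length')
label_length = Input(shape=(1,), name='label_length')

loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
    [y_pred, labels, input_length, label_length])
model = Model(inputs=[y_pred, labels, input_length, label_length], outputs=loss_out)
# The Lambda layer already computes the loss, so compile with a pass-through:
model.compile(optimizer='adam', loss={'ctc': lambda y_true, y_pred: y_pred})
```

The second loss below, `ctc_decode_mse_loss2`, builds a per-timestep alignment target from the greedy predictions and penalizes the softmax output with an L1 distance against its one-hot encoding: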
```python
def ctc_decode_mse_loss2(args):
    batch_size = 32
    num_classes = 35 + 2
    time_step = 21
    label_len = 8
    # y_pred: [32, 21, 37]
    y_pred, labels, input_length, label_length = args
    # Greedy per-timestep predictions, y_max: [32, 21]
    y_max = tf.argmax(y_pred, axis=-1, name='raw_prediction', output_type=tf.int32)
    # labels: [32, 8]
    labels = tf.cast(labels, tf.int32)
    # Alignment target, new_label: [batch, step]
    new_label = tf.zeros((batch_size, time_step), dtype=tf.int32)

    @tf.function
    def body(_i, _label):
        new_sample = tf.zeros((1, time_step), dtype=tf.int32)
        for j in range(0, label_len):
            find_flag = False
            # Loop over the labels and look for matching positions in y_max.
            # (An earlier variant looped over y_pred looking for each label
            # instead, but that misses entries.) Labels that never match are
            # written into a blank (0) slot below.
            for k in range(0, time_step):
                if tf.equal(y_max[_i, k], labels[_i, j]) and tf.not_equal(y_max[_i, k], 0):
                    find_flag = True
                    if k == 0:
                        new_sample = tf.concat([labels[_i:_i+1, j:j+1], new_sample[:, k+1:]], axis=-1)
                    elif k >= time_step - 1:
                        new_sample = tf.concat([new_sample[:, :k], labels[_i:_i+1, j:j+1]], axis=-1)
                    else:
                        new_sample = tf.concat([new_sample[:, :k], labels[_i:_i+1, j:j+1], new_sample[:, k+1:]], axis=-1)
            # An unmatched, non-blank label takes the first position that is
            # blank in both the prediction and the target built so far.
            if not find_flag and tf.not_equal(labels[_i, j], 0):
                find_flag2 = False
                for k in range(0, time_step):
                    if not find_flag2 and tf.equal(y_max[_i, k], 0) and tf.equal(new_sample[0, k], 0):
                        find_flag2 = True
                        if k == 0:
                            new_sample = tf.concat([labels[_i:_i+1, j:j+1], new_sample[:, k+1:]], axis=-1)
                        elif k >= time_step - 1:
                            new_sample = tf.concat([new_sample[:, :k], labels[_i:_i+1, j:j+1]], axis=-1)
                        else:
                            new_sample = tf.concat([new_sample[:, :k], labels[_i:_i+1, j:j+1], new_sample[:, k+1:]], axis=-1)
        # Write the finished row back into the batch tensor; the boundary
        # check is on the batch dimension.
        if _i == 0:
            _label = tf.concat([new_sample[:, :], _label[_i+1:, :]], axis=0)
        elif _i >= batch_size - 1:
            _label = tf.concat([_label[:_i, :], new_sample[:, :]], axis=0)
        else:
            _label = tf.concat([_label[:_i, :], new_sample[:, :], _label[_i+1:, :]], axis=0)
        _i = tf.add(_i, 1)
        return _i, _label

    def cond(_i, _label):
        return tf.less(_i, K.shape(labels)[0])

    # Start at 0 so the first sample is processed as well.
    i = tf.constant(0, dtype=tf.int32)
    _, new_label = tf.while_loop(cond, body, [i, new_label],
                                 shape_invariants=[i.get_shape(), tf.TensorShape([None, None])])
    # One-hot to [batch, classes, step], then transpose to [batch, step, classes].
    new_label = tf.one_hot(new_label, depth=num_classes, axis=1, dtype=tf.float32)
    new_label = tf.transpose(new_label, perm=[0, 2, 1])
    # L1 distance between the softmax output and the aligned one-hot target.
    loss = tf.reduce_mean(tf.abs(new_label - y_pred), axis=-1)
    loss = tf.reduce_mean(loss, axis=-1)
    loss = tf.expand_dims(loss, -1)
    return loss
```
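To make the alignment rule concrete, here is a hypothetical pure-Python walk-through of the per-sample placement logic (a toy re-implementation, not part of the original): every timestep whose greedy prediction matches a label (and is non-blank) receives that label; a label with no match falls into the first slot that is blank in both the prediction and the target built so far.

```python
def align_toy(y_max_row, label_row, time_step):
    """Toy re-implementation of the placement rule in ctc_decode_mse_loss2."""
    new_sample = [0] * time_step
    for lab in label_row:
        if lab == 0:
            continue  # blank labels are never placed
        matched = False
        # First pass: put the label wherever the greedy prediction matches it.
        for k in range(time_step):
            if y_max_row[k] == lab:
                new_sample[k] = lab
                matched = True
        # Second pass: an unmatched label takes the first doubly-blank slot.
        if not matched:
            for k in range(time_step):
                if y_max_row[k] == 0 and new_sample[k] == 0:
                    new_sample[k] = lab
                    break
    return new_sample

# Greedy prediction [0, 5, 5, 0, 3, 0] with labels [5, 3, 7]:
# 5 and 3 land where they were predicted; 7 falls into the first blank slot.
print(align_toy([0, 5, 5, 0, 3, 0], [5, 3, 7], 6))  # [7, 5, 5, 0, 3, 0]
```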