import tensorflow as tf
import keras.backend as K


def ctc_lambda_func(args):
    """Defines the CTC loss function.

    Args:
        args: tuple of (y_pred: predictions, labels: ground-truth labels,
            input_length: LSTM timesteps, label_length: label lengths).
    """
    y_pred, labels, input_length, label_length = args
    # return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
    return my_ctc_batch_cost(labels, y_pred, input_length, label_length, mode=1)


def my_ctc_batch_cost(y_true, y_pred, input_length, label_length, mode=0):
    """Runs the CTC loss algorithm on each batch element.

    Args:
        y_true: tensor `(samples, max_string_length)` containing the truth labels.
        y_pred: tensor `(samples, time_steps, num_categories)` containing the
            prediction, or output of the softmax.
        input_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_pred`.
        label_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_true`.
        mode: 0 returns the plain CTC loss; 1 applies focal weighting on top.

    Returns:
        Tensor with shape `(samples, 1)` containing the CTC loss of each element
        (a scalar when `mode == 1`, since `focal_ctc` reduces over the batch).
    """
    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
    sparse_labels = tf.cast(
        K.ctc_label_dense_to_sparse(y_true, label_length), tf.int32)

    # ctc_loss expects time-major log-probabilities: (time_steps, samples, num_categories).
    y_pred = tf.math.log(
        tf.compat.v1.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())

    loss = tf.compat.v1.nn.ctc_loss(
        inputs=y_pred,
        labels=sparse_labels,
        sequence_length=input_length,
        preprocess_collapse_repeated=False,
        ctc_merge_repeated=False)
    loss = tf.expand_dims(loss, 1)
    if mode == 1:
        loss = focal_ctc(sparse_labels, y_pred, input_length, loss)
    return loss


def focal_ctc(targets, logits, seq_len, ctc_loss, alpha=0.5, gamma=2.0):
    """Computes focal loss on top of a per-sample CTC loss.

    Inputs: alpha, gamma, targets, logits, seq_len, ctc_loss.
    Default values: alpha=0.5 and gamma=2.0.
    Output: loss (scalar, averaged over the batch).
    """
    # ctc_loss = tf.compat.v1.nn.ctc_loss(labels=targets, inputs=logits,
    #                                     sequence_length=seq_len, time_major=True)
    # p is the probability the model assigns to the target sequence.
    p = tf.exp(-ctc_loss)
    # alpha * (1 - p)**gamma * ctc_loss: down-weights easy (high-p) samples.
    focal_ctc_loss = tf.multiply(tf.multiply(alpha, tf.pow(1 - p, gamma)), ctc_loss)
    loss = tf.reduce_mean(focal_ctc_loss)
    return loss
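# Usage sketch (an assumption, not part of the original file): ctc_lambda_func
# is typically attached to a Keras model through a Lambda layer, with the labels
# and the two length tensors fed in as extra model inputs. `base_model`, the
# layer names, and the shapes below are illustrative placeholders.
def build_ctc_training_model(base_model, max_label_len=8):
    from keras.layers import Input, Lambda
    from keras.models import Model

    labels = Input(name='labels', shape=(max_label_len,), dtype='int32')
    input_length = Input(name='input_length', shape=(1,), dtype='int32')
    label_length = Input(name='label_length', shape=(1,), dtype='int32')
    y_pred = base_model.output  # (samples, time_steps, num_categories) softmax

    # The Lambda layer's output *is* the per-sample loss, shape (samples, 1).
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
        [y_pred, labels, input_length, label_length])
    model = Model(inputs=[base_model.input, labels, input_length, label_length],
                  outputs=loss_out)
    # The graph already computes the loss, so compile with an identity loss.
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam')
    return model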
def ctc_decode_mse_loss2(args):
    """Builds a per-timestep pseudo target from the greedy decode and the
    labels, then scores y_pred against it with a mean-absolute-error distance.

    For each sample, every label character found at some timestep of the greedy
    prediction is written into that timestep of the pseudo target; labels that
    are missed are written into the first free blank (0) slot. 0 is the padding
    blank throughout.
    """
    batch_size = 32
    num_classes = 35 + 2
    time_step = 21
    label_len = 8
    # y_pred: [32, 21, 37]
    y_pred, labels, input_length, label_length = args
    # y_max: [32, 21], the greedy (argmax) decode per timestep.
    y_max = tf.argmax(y_pred, axis=-1, name='raw_prediction', output_type=tf.int32)
    # labels: [32, 8]
    labels = tf.cast(labels, tf.int32)
    # new_label: [batch, step]
    new_label = tf.zeros((batch_size, time_step), dtype=tf.int32)

    @tf.function
    def body(_i, _label):
        # An earlier version looped over y_pred and searched the labels, but
        # that misses label characters; instead, loop over the labels and
        # search y_pred, writing any missed label into a free blank (0) slot.
        new_sample = tf.zeros((1, time_step), dtype=tf.int32)
        for j in range(0, label_len):
            find_flag = tf.constant(False)
            for k in range(0, time_step):
                # Place label j at timestep k where the greedy decode matches
                # it. Python `and` cannot combine tensors; use tf.logical_and.
                if tf.logical_and(tf.equal(y_max[_i, k], labels[_i, j]),
                                  tf.not_equal(y_max[_i, k], 0)):
                    find_flag = tf.constant(True)
                    # tf.concat handles the empty edge slices at k == 0 and
                    # k == time_step - 1.
                    new_sample = tf.concat(
                        [new_sample[:, :k], labels[_i:_i + 1, j:j + 1],
                         new_sample[:, k + 1:]], axis=-1)
            # Label j never matched: drop it into the first slot that is blank
            # in both the greedy decode and the pseudo target built so far.
            if tf.logical_and(tf.logical_not(find_flag),
                              tf.not_equal(labels[_i, j], 0)):
                find_flag2 = tf.constant(False)
                for k in range(0, time_step):
                    if tf.logical_and(
                            tf.logical_not(find_flag2),
                            tf.logical_and(tf.equal(y_max[_i, k], 0),
                                           tf.equal(new_sample[0, k], 0))):
                        find_flag2 = tf.constant(True)
                        new_sample = tf.concat(
                            [new_sample[:, :k], labels[_i:_i + 1, j:j + 1],
                             new_sample[:, k + 1:]], axis=-1)
        # Write this sample's row back into the batch tensor. A single concat
        # covers _i == 0 and _i == batch_size - 1 via empty slices (the
        # original branching compared the batch index against time_step).
        _label = tf.concat([_label[:_i, :], new_sample, _label[_i + 1:, :]], axis=0)
        _i = tf.add(_i, 1)
        return _i, _label

    def cond(_i, _label):
        return tf.less(_i, K.shape(labels)[0])

    # Start at 0 so the first batch element is processed too (the original
    # started at 1 and skipped it).
    i = tf.constant(0, dtype=tf.int32)
    _, new_label = tf.while_loop(
        cond, body, [i, new_label],
        shape_invariants=[i.get_shape(), tf.TensorShape([None, None])])
    # one_hot with axis=1 yields (batch, classes, steps); transpose to
    # (batch, steps, classes) to match y_pred.
    new_label = tf.one_hot(new_label, depth=num_classes, axis=1, dtype=tf.float32)
    new_label = tf.compat.v1.transpose(new_label, perm=[0, 2, 1])
    loss = tf.reduce_mean(tf.abs(new_label - y_pred), axis=-1)
    loss = tf.reduce_mean(loss, axis=-1)
    loss = tf.expand_dims(loss, -1)
    return loss
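# Smoke test (illustrative; not part of the original file). Exercises
# ctc_decode_mse_loss2 eagerly on random tensors matching the shapes hard-coded
# above: batch 32, 21 timesteps, 37 classes, labels padded to length 8 with 0
# as the blank.
if __name__ == '__main__':
    batch, steps, classes, max_len = 32, 21, 37, 8
    y_pred = tf.nn.softmax(tf.random.normal((batch, steps, classes)))
    labels = tf.random.uniform((batch, max_len), minval=0, maxval=classes,
                               dtype=tf.int32)
    input_length = tf.fill((batch, 1), steps)
    label_length = tf.fill((batch, 1), max_len)
    loss = ctc_decode_mse_loss2([y_pred, labels, input_length, label_length])
    print('loss shape:', loss.shape)  # expected: (32, 1)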