def table_post_process(table_text_list, predict_result, threshold=0.5, is_list=False):
    """Binarize per-cell scores and reshape them to the layout of a text table.

    Args:
        table_text_list: list of rows, each row an iterable sequence of cells.
            Only the number of rows and the number of cells per row are used.
        predict_result: per-cell scores consumed in row-major cell order;
            element ``i`` is read as ``predict_result[i][0]``. Must provide
            ``.tolist()`` unless ``is_list`` is True.
        threshold: scores ``>= threshold`` become ``"1"``, all others ``"0"``.
        is_list: pass True when ``predict_result`` is already a Python list.

    Returns:
        A list of rows mirroring ``table_text_list`` whose cells are the
        strings ``"1"``/``"0"``. Returns ``[]`` (after printing a warning)
        when there are fewer predictions than cells. Surplus predictions are
        ignored.
    """
    if not is_list:
        predict_result = predict_result.tolist()
    predict_list = ["1" if pred[0] >= threshold else "0" for pred in predict_result]

    # The label table is built row-for-row from table_text_list, so comparing
    # row counts afterwards can never detect anything. The real mismatch is
    # having fewer predictions than cells, which used to raise IndexError.
    cell_count = sum(len(row) for row in table_text_list)
    if len(predict_list) < cell_count:
        print("table_post_process 输出label维度与text不一致!")
        return []

    table_label_list = []
    start = 0
    for row in table_text_list:
        end = start + len(row)
        table_label_list.append(predict_list[start:end])
        start = end
    return table_label_list


def table_post_process_2(table_text_list, predict_result, threshold=0.5):
    """Binarize the first 2-D score grid of ``predict_result`` into labels.

    Args:
        table_text_list: list of rows, each row a sequence of cells. It is
            used only to check that the score grid is large enough.
        predict_result: object with ``.tolist()``. Element ``[0]`` of that
            list is treated as a grid of rows of scalar scores. Presumably
            this is a batch of one (batch, rows, cols) array; confirm against
            callers.
        threshold: scores ``>= threshold`` become ``"1"``, all others ``"0"``.

    Returns:
        The grid of ``"1"``/``"0"`` strings, with the same shape as the score
        grid. Returns ``[]`` (after printing a warning) when the grid has
        fewer rows than the text table, or a row shorter than the matching
        text row. Larger grids are returned unchanged, as before.
    """
    predict_grid = predict_result.tolist()[0]
    predict_list = [
        ["1" if score >= threshold else "0" for score in row] for row in predict_grid
    ]

    # Previously predict_list was compared with predict_grid, its own source,
    # so the check could never fail and table_text_list was unused. Compare
    # against the text table instead. Only a grid too small to cover the text
    # counts as a mismatch.
    too_few_rows = len(predict_list) < len(table_text_list)
    too_few_cols = any(
        len(label_row) < len(text_row)
        for label_row, text_row in zip(predict_list, table_text_list)
    )
    if too_few_rows or too_few_cols:
        print("table_post_process 输出label维度与text不一致!")
        return []
    return predict_list