def table_post_process(table_text_list, predict_result, threshold=0.5):
    """Turn paired model scores into a per-cell "1"/"0" label grid shaped like the table.

    ``predict_result`` is converted with ``.tolist()`` and consumed two rows at a
    time: rows ``2k`` and ``2k + 1`` provide the row score and the column score
    for the k-th cell (each read from index 1 of its row). A cell is labelled
    "1" when ``row_score + col_score >= 2 * threshold`` (i.e. the mean of the two
    scores reaches ``threshold``) and "0" otherwise. Labels are assigned to cells
    in row-major order of ``table_text_list``.

    Args:
        table_text_list: Iterable of table rows, each an iterable of cells. Only
            the number of cells in each row is used.
        predict_result: Object exposing ``tolist()`` whose rows are indexable at
            position 1 (presumably a NumPy array or torch tensor of shape
            ``(2 * n_cells, n_classes)`` -- TODO confirm against the caller).
        threshold: Decision threshold applied to the mean of each score pair.

    Returns:
        A list of rows of "1"/"0" strings with the same shape as
        ``table_text_list``. If there are fewer score pairs than table cells, a
        warning is printed and ``[]`` is returned. Surplus score pairs, including
        a trailing unpaired row, are ignored.
    """
    scores = predict_result.tolist()

    # Rows alternate row-score / column-score. Zipping the even and odd slices
    # yields one pair per cell; zip stops at the shorter slice, so an odd
    # trailing row is dropped instead of raising IndexError.
    predict_list = [
        "1" if row_score[1] + col_score[1] >= 2 * threshold else "0"
        for row_score, col_score in zip(scores[0::2], scores[1::2])
    ]

    # Validate while assigning. The previous post-hoc check compared output and
    # input row counts, which are always equal, so a shortfall used to surface
    # as an IndexError instead of this warning.
    labels = iter(predict_list)
    table_label_list = []
    for row in table_text_list:
        row_label = [next(labels, None) for _ in row]
        if None in row_label:
            print("table_post_process 输出label维度与text不一致!")
            return []
        table_label_list.append(row_label)
    return table_label_list