123456789101112131415161718192021222324252627 |
def table_post_process(table_text_list, predict_result, threshold=0.5):
    """Turn flat per-cell prediction scores into a table-shaped grid of labels.

    ``predict_result`` is consumed two entries at a time: entries ``2k`` and
    ``2k + 1`` together decide the label of the k-th cell, where cells are
    counted row by row over ``table_text_list``. (The original code names the
    two entries the "row" and "column" predictions.)

    Args:
        table_text_list: Sequence of rows, each row a sized sequence of cells.
            Only the shape is used; cell contents are never read.
        predict_result: Object exposing ``tolist()`` that yields a sequence of
            entries, each indexable at position 1.
            NOTE(review): index 1 is presumably the positive-class score of a
            two-class output -- confirm against the caller.
        threshold: A cell is labelled ``"1"`` when the sum of its two scores
            is at least ``2 * threshold``; otherwise ``"0"``.

    Returns:
        A list of rows of ``"1"``/``"0"`` strings with the same shape as
        ``table_text_list``. An empty list is returned (after printing a
        warning) when the predictions cannot cover every cell: either the
        number of entries is odd, or there are fewer pairs than cells.
        Surplus trailing pairs are ignored.
    """
    scores = predict_result.tolist()
    cell_total = sum(len(row) for row in table_text_list)

    # Validate up front. Comparing the number of output rows with the number
    # of input rows can never fail (one output row is built per input row),
    # so the shape problem has to be detected on the prediction count instead
    # of surfacing as an IndexError while indexing.
    if len(scores) % 2 != 0 or len(scores) // 2 < cell_total:
        print("table_post_process 输出label维度与text不一致!")
        return []

    cutoff = 2 * threshold
    predict_list = [
        "1" if first[1] + second[1] >= cutoff else "0"
        for first, second in zip(scores[0::2], scores[1::2])
    ]

    # Hand the labels out in order, one per cell, following the table shape.
    remaining = iter(predict_list)
    return [[next(remaining) for _ in row] for row in table_text_list]
|