123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
def table_post_process(table_text_list, predict_result, threshold=0.5, is_list=False):
    """Arrange flat per-cell predictions into the row layout of a table.

    Each entry of ``predict_result`` is indexed with ``[0]`` to obtain a score;
    scores greater than or equal to ``threshold`` become the label ``"1"``,
    all others ``"0"``.  Labels are then handed out to the cells of
    ``table_text_list`` row by row, in order.

    Args:
        table_text_list: Sequence of rows, each row a sized sequence of cells.
            Only the row lengths are used; the cell contents are never read.
        predict_result: Per-cell predictions in row-major order.  Unless
            ``is_list`` is true, ``.tolist()`` is called on it first
            (presumably a NumPy array or tensor -- confirm against callers).
        threshold: Minimum score that maps to ``"1"``.
        is_list: Set to True when ``predict_result`` is already a plain list.

    Returns:
        A list of rows of ``"1"``/``"0"`` strings with the same row lengths as
        ``table_text_list``.  If there are fewer predictions than cells, a
        message is printed and an empty list is returned.  Surplus predictions
        are ignored.
    """
    if not is_list:
        predict_result = predict_result.tolist()

    predict_list = ["1" if item[0] >= threshold else "0" for item in predict_result]

    table_label_list = []
    start = 0
    for row in table_text_list:
        end = start + len(row)
        # The original compared the number of label rows with the number of
        # text rows, which can never differ; the real failure mode is running
        # out of predictions, so that is what is checked here.
        if end > len(predict_list):
            print("table_post_process 输出label维度与text不一致!")
            return []
        table_label_list.append(predict_list[start:end])
        start = end
    return table_label_list
def table_post_process_2(table_text_list, predict_result, threshold=0.5):
    """Convert the first item of a batch of 2-D score grids into string labels.

    Only element ``[0]`` of ``predict_result`` is used.  Every score in it that
    is greater than or equal to ``threshold`` becomes ``"1"``; every other
    score becomes ``"0"``.

    Args:
        table_text_list: Not used by this function; kept so existing callers
            keep working.
        predict_result: Batch of score grids.  Objects exposing ``.tolist()``
            are converted first (presumably a NumPy array or tensor -- confirm
            against callers); plain nested lists are accepted as-is.
        threshold: Minimum score that maps to ``"1"``.

    Returns:
        A list of rows of ``"1"``/``"0"`` strings with the same shape as
        ``predict_result[0]``, or an empty list when the batch is empty.
    """
    if hasattr(predict_result, "tolist"):
        predict_result = predict_result.tolist()

    # An empty batch has no first element to label.
    if not predict_result:
        return []

    # The original re-checked that the number of output rows equals the number
    # of input rows; one output row is built per input row, so that check
    # could never fail and has been dropped.
    return [
        ["1" if score >= threshold else "0" for score in row]
        for row in predict_result[0]
    ]
|