1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
def table_post_process(table_text_list, predict_result, threshold=0.5):
    """Binarize flat per-cell scores and reshape them to the table's layout.

    Args:
        table_text_list: The table as a sequence of rows, each a sized
            sequence of cells. Only its shape is used.
        predict_result: Array-like exposing ``tolist()``. Element ``i`` of
            that list is indexed with ``[0]`` to get the score of the i-th
            cell in row-major order, so presumably shape ``(n_cells, 1)``
            -- TODO confirm against the model output.
        threshold: Scores ``>= threshold`` map to ``"1"``, others to ``"0"``.

    Returns:
        A list with one entry per text row, each holding one ``"1"``/``"0"``
        label per cell. Predictions beyond the total cell count are ignored.
        If there are fewer predictions than cells, a warning is printed and
        ``[]`` is returned.
    """
    labels = [
        "1" if score[0] >= threshold else "0"
        for score in predict_result.tolist()
    ]

    # Validate against the total cell count up front: with too few
    # predictions, filling the table would otherwise raise IndexError
    # part-way through instead of reporting the mismatch.
    row_lengths = [len(row) for row in table_text_list]
    if sum(row_lengths) > len(labels):
        print("table_post_process 输出label维度与text不一致!")
        return []

    # Consume labels row by row in row-major order.
    table_label_list = []
    start = 0
    for n_cols in row_lengths:
        table_label_list.append(labels[start:start + n_cols])
        start += n_cols
    return table_label_list
def table_post_process_2(table_text_list, predict_result, threshold=0.5):
    """Binarize a 2-D grid of per-cell scores into ``"1"``/``"0"`` labels.

    Args:
        table_text_list: The table as a sequence of rows, each a sized
            sequence of cells. Used only to check that the prediction grid
            covers every text cell.
        predict_result: Array-like exposing ``tolist()``. The first element
            along axis 0 is taken as the score grid (a list of rows of
            scores), so presumably shape ``(1, rows, cols)`` -- TODO confirm
            against the model output.
        threshold: Scores ``>= threshold`` map to ``"1"``, others to ``"0"``.

    Returns:
        The full binarized grid as a list of rows. Rows or columns beyond
        the text table, e.g. padding, are kept. If the grid does not cover
        every text cell, a warning is printed and ``[]`` is returned.
    """
    score_grid = predict_result.tolist()[0]
    predict_list = [
        ["1" if score >= threshold else "0" for score in row]
        for row in score_grid
    ]

    # The grid must have at least as many rows as the text table, and each
    # of those rows at least as many columns as the matching text row.
    too_small = len(predict_list) < len(table_text_list) or any(
        len(label_row) < len(text_row)
        for label_row, text_row in zip(predict_list, table_text_list)
    )
    if too_small:
        print("table_post_process 输出label维度与text不一致!")
        return []
    return predict_list
|