predict.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from BiddingKG.dl.table_head.models.model import get_model
  2. from BiddingKG.dl.table_head.post_process import table_post_process
  3. from BiddingKG.dl.table_head.pre_process import my_data_loader, table_pre_process
  4. # init model
  5. input_shape = (3, 10, 60)
  6. output_shape = (3,)
  7. model = get_model(input_shape, output_shape)
  8. # load weights
  9. model_path = "checkpoints/best.hdf5"
  10. model.load_weights(model_path)
  11. def predict(table_text_list):
  12. # 表格单元格数*2 即为单次预测batch_size
  13. data_list = table_pre_process(table_text_list, [], 0, is_train=False)
  14. batch_size = len(data_list)
  15. # print("batch_size", batch_size)
  16. # 数据预处理
  17. predict_x = my_data_loader(data_list, [], batch_size, is_train=False)
  18. # 预测
  19. predict_result = model.predict_generator(predict_x, steps=1)
  20. # print("predict_result", predict_result.shape)
  21. # 数据后处理
  22. table_label_list = table_post_process(table_text_list, predict_result)
  23. return table_label_list
  24. if __name__ == '__main__':
  25. _str = "[['序号', '投标人名称', '价格得分', '技术得分', '商务得分', '综合得分', '排名'], " \
  26. "['序号', '投标人名称', '比例(20%),', '比例(45%),', '比例(35%),', '100%', '排名'], " \
  27. "['1', '广州中科雅图信息技术有限公司', '19.71', '11.50', '11.00', '42.21', '3'], " \
  28. "['2', '核工业赣州工程勘察院', '19.64', '15.00', '11.00', '45.64', '2'], " \
  29. "['3', '广东晟腾地信科技有限公司', '20.00', '16.17', '14.00', '50.17', '1']]"
  30. data_list = eval(_str)
  31. print("len(data_list)", len(data_list))
  32. predict(data_list)