# predict.py
  1. import codecs
  2. import re
  3. from bs4 import BeautifulSoup
  4. from BiddingKG.dl.table_head.models.model import get_model
  5. from BiddingKG.dl.table_head.post_process import table_post_process
  6. from BiddingKG.dl.table_head.pre_process import my_data_loader, table_pre_process
  7. from BiddingKG.dl.interface.Preprocessing import tableToText, segment
  8. # init model
  9. input_shape = (6, 10, 60)
  10. output_shape = (1,)
  11. model = get_model(input_shape, output_shape)
  12. # load weights
  13. model_path = "checkpoints/best.hdf5"
  14. model.load_weights(model_path)
  15. def predict(table_text_list):
  16. # 表格单元格数*2 即为单次预测batch_size
  17. data_list = table_pre_process(table_text_list, [], 0, is_train=False)
  18. batch_size = len(data_list)
  19. # print("batch_size", batch_size)
  20. # 数据预处理
  21. predict_x = my_data_loader(data_list, [], batch_size, is_train=False)
  22. # 预测
  23. predict_result = model.predict_generator(predict_x, steps=1)
  24. # print("predict_result", predict_result.shape)
  25. # 数据后处理
  26. table_label_list = table_post_process(table_text_list, predict_result)
  27. return table_label_list
  28. def predict_html():
  29. def get_trs(tbody):
  30. #获取所有的tr
  31. trs = []
  32. objs = tbody.find_all(recursive=False)
  33. for obj in objs:
  34. if obj.name=="tr":
  35. trs.append(obj)
  36. if obj.name=="tbody":
  37. for tr in obj.find_all("tr",recursive=False):
  38. trs.append(tr)
  39. return trs
  40. def get_table(tbody):
  41. trs = get_trs(tbody)
  42. inner_table = []
  43. for tr in trs:
  44. tr_line = []
  45. tds = tr.findChildren(['td', 'th'], recursive=False)
  46. if len(tds) == 0:
  47. tr_line.append(re.sub('\xa0', '', segment(tr,final=False))) # 2021/12/21 修复部分表格没有td 造成数据丢失
  48. for td in tds:
  49. tr_line.append(re.sub('\xa0', '', segment(td,final=False)))
  50. inner_table.append(tr_line)
  51. return inner_table
  52. def fix_table(inner_table, fix_value=""):
  53. maxWidth = 0
  54. for item in inner_table:
  55. if len(item)>maxWidth:
  56. maxWidth = len(item)
  57. for i in range(len(inner_table)):
  58. if len(inner_table[i])<maxWidth:
  59. for j in range(maxWidth-len(inner_table[i])):
  60. inner_table[i].append(fix_value)
  61. return inner_table
  62. text = codecs.open("C:\\Users\\\Administrator\\Desktop\\2.html","r",encoding="utf8").read()
  63. content = str(BeautifulSoup(text).find("div",id="pcontent"))
  64. soup = BeautifulSoup(content, 'lxml')
  65. table_list = []
  66. tbodies = soup.find_all('tbody')
  67. for tbody_index in range(1,len(tbodies)+1):
  68. tbody = tbodies[len(tbodies)-tbody_index]
  69. table_list.append(tbody)
  70. table_fix_list = []
  71. for tbody in table_list:
  72. inner_table = get_table(tbody)
  73. inner_table = fix_table(inner_table)
  74. table_fix_list.append(inner_table)
  75. for table in table_fix_list:
  76. print("="*30)
  77. print(table)
  78. print(predict(table))
  79. if __name__ == '__main__':
  80. # _str = "[['序号', '投标人名称', '价格得分', '技术得分', '商务得分', '综合得分', '排名'], " \
  81. # "['序号', '投标人名称', '比例(20%),', '比例(45%),', '比例(35%),', '100%', '排名'], " \
  82. # "['1', '广州中科雅图信息技术有限公司', '19.71', '11.50', '11.00', '42.21', '3'], " \
  83. # "['2', '核工业赣州工程勘察院', '19.64', '15.00', '11.00', '45.64', '2'], " \
  84. # "['3', '广东晟腾地信科技有限公司', '20.00', '16.17', '14.00', '50.17', '1']]"
  85. #
  86. # data_list = eval(_str)
  87. # print("len(data_list)", len(data_list))
  88. # predict(data_list)
  89. predict_html()