predict.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import os, sys
  2. import numpy as np
  3. import re
  4. import tensorflow as tf
  5. import jieba
  6. import gensim
  7. maxlen = 512
  8. words_size = 128
  9. w2v_filepath = os.path.dirname(os.path.abspath(__file__))+"/wiki_128_word_embedding_new.vector"
  10. model_w2v = gensim.models.KeyedVectors.load_word2vec_format(w2v_filepath, binary=True)
  11. def get_words_matrix(words):
  12. if words in model_w2v.key_to_index:
  13. return model_w2v[words]
  14. else:
  15. return model_w2v['unk']
  16. class ModelRelationExtraction:
  17. def __init__(self):
  18. self.model_file = os.path.dirname(os.path.abspath(__file__))+"/models/model_attachment_classify"
  19. self.sess = tf.compat.v1.Session(graph=tf.Graph())
  20. self.classes_dict = {
  21. 0: '其他',
  22. 1: '招标文件',
  23. 2: '限价(控制价)',
  24. 3: '工程量清单',
  25. 4: '采购清单',
  26. 5: '评标办法'
  27. }
  28. self.get_model()
  29. def get_model(self):
  30. with self.sess.as_default() as sess:
  31. with sess.graph.as_default():
  32. meta_graph_def = tf.compat.v1.saved_model.loader.load(sess, tags=["serve"], export_dir=self.model_file)
  33. signature_key = tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  34. signature_def = meta_graph_def.signature_def
  35. input0 = sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["input0"].name)
  36. print(input0.shape)
  37. output = sess.graph.get_tensor_by_name(signature_def[signature_key].outputs["outputs"].name)
  38. self.model = [input0, output]
  39. return self.model
  40. def text_process(self, attachmentcon):
  41. text = attachmentcon
  42. text = re.sub("\n+", ',', text)
  43. text = re.sub("\s+|?+", '', text)
  44. text = re.sub("[\.·_]{2,}", ',', text)
  45. text = re.sub("_", '', text)
  46. text = text[:2500]
  47. tokens = list(jieba.cut(text))
  48. return tokens
  49. def evaluate(self, attachmentcon):
  50. text = str(attachmentcon)
  51. tokens = self.text_process(text)
  52. maxlen = 512
  53. tokens = tokens[:maxlen]
  54. words_matrix = np.zeros((maxlen, words_size))
  55. for i in range(len(tokens)):
  56. words_matrix[i] = np.array(get_words_matrix(tokens[i]))
  57. words_matrix = np.array([words_matrix])
  58. pred = limit_run(self.sess, [self.model[1]], feed_dict={self.model[0]: words_matrix})[0]
  59. pred_label = np.argmax(pred[0])
  60. cn_label = self.classes_dict[pred_label]
  61. return pred_label, cn_label
  62. def limit_run(sess, list_output, feed_dict, max_batch=1024):
  63. len_sample = 0
  64. if len(feed_dict.keys()) > 0:
  65. len_sample = len(feed_dict[list(feed_dict.keys())[0]])
  66. if len_sample > max_batch:
  67. list_result = [[] for _ in range(len(list_output))]
  68. _begin = 0
  69. while _begin < len_sample:
  70. new_dict = dict()
  71. for _key in feed_dict.keys():
  72. if isinstance(feed_dict[_key], (float, int, np.int32, np.float_, np.float16, np.float32, np.float64)):
  73. new_dict[_key] = feed_dict[_key]
  74. else:
  75. new_dict[_key] = feed_dict[_key][_begin:_begin+max_batch]
  76. _output = sess.run(list_output,feed_dict=new_dict)
  77. for _index in range(len(list_output)):
  78. list_result[_index].extend(_output[_index])
  79. _begin += max_batch
  80. else:
  81. list_result = sess.run(list_output, feed_dict=feed_dict)
  82. return list_result
  83. if __name__ == '__main__':
  84. text = '''招标文件项目编号:SDGP370302202102000110项目名称:淄川经济开发区中心小学校园智能化采购项目采购人:山东淄川经
  85. 济开发区管理委员会采购代理机构:淄博正益招标有限公司发出日期:2021年8月目录第一章投标邀请7一、项目基本情况7二、申请人的资格要
  86. 求8三、获取招标文件8四、提交投标文件截止时间、开标时间和地点8五、公告期限9六、其他补充事宜9第二章投标人须知11一、总则161.采
  87. 购人、采购代理机构及投标人162.资金来源183.投标费用184.适用法律18二、招标文件185.招标文件构成186.招标文件的澄清与修改207.投
  88. 标截止时间的顺延20三、投标文件的编制208.编制要求209.投标范围及投标文件中标准和计量单位的使用2110.投标文件构成2211.投标报价241
  89. 2.电子版投标文件2513.投标保证金2614.投标有效期2615.投标文件的签署及规定26四、投标文件的递交2616.投标文件的递交2617.递交
  90. 投标文件的截止时间2718.投标文件的接收、修改与撤回27五、开标及评标2719.开标2720.资格审查2821.组建评标委员会2922.投标文件符
  91. 合性审查与澄清3023.投标偏离3224.投标无效3225.比较和评价3326.废标3527.保密要求36六、确定中标3628.中标候选人的确定原则及标
  92. 准3629.确定中标候选人和中标人3630.采购任务取消3631.中标通知书3632.签订合同3633.履约保证金3734.政府采购融资担保3735.预付
  93. 款3736.廉洁自律规定3737.人员回避3738.质疑与接收3739.项目其他相关费用3940.合同公示3941.验收4042.履约验收公示4043.招标文
  94. 件解释权40第三章货物需求41一、项目概述41
  95. '''
  96. test_text = re.sub('\n', '', text)
  97. model = ModelRelationExtraction()
  98. print(model.evaluate(test_text))