channel_predictor.py

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/6/10 0010 14:23
import BiddingKG.dl.interface.Preprocessing as Preprocessing
from BiddingKG.dl.common.Utils import getVocabAndMatrix, getModel_w2v, precision, recall, f1_score
import numpy as np
import pandas as pd
import copy
import tensorflow as tf
import fool
import re
import os
import time

word_model = getModel_w2v()
vocab, embedding_matrix = getVocabAndMatrix(word_model, Embedding_size=128)
word_index = {k: v for v, k in enumerate(vocab)}
height, width = embedding_matrix.shape
sequen_len = 200  # content sequence length; 150 and 200 were both tried
title_len = 30
sentence_num = 10
kws = '供货商|候选人|供应商|入选人|项目|选定|预告|中标|成交|补遗|延期|报名|暂缓|结果|意向|出租|补充|合同|限价|比选|指定|工程|废标|取消|中止|流标|资质|资格|地块|招标|采购|货物|租赁|计划|宗地|需求|来源|土地|澄清|失败|探矿|预审|变更|变卖|遴选|撤销|意见|恢复|采矿|更正|终止|废置|报建|流拍|供地|登记|挂牌|答疑|中选|受让|拍卖|竞拍|审查|入围|更改|条件|洽谈|乙方|后审|控制|暂停|用地|询价|预'
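
# Pipeline overview: DocChannel chains two frozen TensorFlow graphs. The "type"
# model first assigns one of the five sources in lb_type; only documents typed
# as 采招数据 (index 0) are passed on to the "life" model, which assigns one of
# the nine life-cycle channels in lb_life. `kws` is the keyword alternation that
# get_kw_senten feeds to re.finditer to cut keyword-centred windows out of long
# documents, so the fixed 200-token input keeps the most informative spans.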
class DocChannel():
    def __init__(self, life_model='/model/channel.pb', type_model='/model/doctype.pb'):
        self.lift_sess, self.lift_title, self.lift_content, self.lift_prob, self.lift_softmax, \
            self.mask, self.mask_title = self.load_life(life_model)
        self.type_sess, self.type_title, self.type_content, self.type_prob, self.type_softmax, \
            self.type_mask, self.type_mask_title = self.load_type(type_model)
        lb_type = ['采招数据', '土地矿产', '拍卖出让', '产权交易', '新闻资讯']
        lb_life = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
        self.id2type = {k: v for k, v in enumerate(lb_type)}
        self.id2life = {k: v for k, v in enumerate(lb_life)}
    def load_life(self, life_model):
        with tf.Graph().as_default() as graph:
            output_graph_def = graph.as_graph_def()
            with open(os.path.dirname(__file__) + life_model, 'rb') as f:
                output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name='')
            print("%d ops in the final graph" % len(output_graph_def.node))
            del output_graph_def
            sess = tf.Session(graph=graph)
            sess.run(tf.global_variables_initializer())
            inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
            prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
            title = sess.graph.get_tensor_by_name('inputs/title:0')
            mask = sess.graph.get_tensor_by_name('inputs/mask:0')
            mask_title = sess.graph.get_tensor_by_name('inputs/mask_title:0')
            # logit = sess.graph.get_tensor_by_name('output/logit:0')
            softmax = sess.graph.get_tensor_by_name('output/softmax:0')
            return sess, title, inputs, prob, softmax, mask, mask_title
    def load_type(self, type_model):
        # The loading logic is identical to load_life and both frozen graphs
        # expose the same tensor names, so simply delegate.
        return self.load_life(type_model)
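
    # Note: tf.Session, tf.global_variables_initializer and tf.import_graph_def
    # are TensorFlow 1.x APIs. Under TensorFlow 2.x this loading code would need
    # the tf.compat.v1 equivalents (and tf.compat.v1.disable_eager_execution()).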
    def predict_process_backup(self, docid='', doctitle='', dochtmlcon=''):
        # Earlier preprocessing variant, kept for reference; predict_process
        # below is the live path.
        # print('preparing preprocessing')
        def get_kw_senten(s, span=10):
            doc_sens = []
            tmp = 0
            num = 0
            end_idx = 0
            for it in re.finditer(kws, s):  # '|'.join(keywordset)
                left = s[end_idx:it.end()].split()
                right = s[it.end():].split()
                tmp_seg = s[tmp:it.start()].split()
                if len(tmp_seg) > span or tmp == 0:
                    doc_sens.append(' '.join(left[-span:] + right[:span]))
                    end_idx = it.end() + 1 + len(' '.join(right[:span]))
                tmp = it.end()
                num += 1
                if num >= sentence_num:
                    break
            if doc_sens == []:
                doc_sens.append(s)
            return doc_sens

        def word2id(wordlist, max_len=sequen_len):
            ids = [word_index.get(w, 0) for w in wordlist]
            ids = ids[:max_len] if len(ids) >= max_len else ids + [0] * (max_len - len(ids))
            assert len(ids) == max_len
            return ids

        cost_time = dict()
        datas = []
        datas_title = []
        # articles = [[docid, dochtmlcon, '', '', doctitle]]
        try:
            # list_articles = Preprocessing.get_preprocessed_article(articles, cost_time)
            # list_sentences = Preprocessing.get_preprocessed_sentences(list_articles, True, cost_time)
            # sen_words = [sen.tokens for sen in list_sentences[0]]
            # words = [it for sen in sen_words for it in sen]
            # segword_content = ' '.join(words)
            # segword_title = ' '.join(fool.cut(doctitle)[0])
            segword_content = dochtmlcon
            segword_title = doctitle
        except:
            segword_content = ''
            segword_title = ''
        segword_title = ' '.join([it for it in segword_title.split() if it.isalpha() and it in vocab][:title_len])
        segword_content = ' '.join([it for it in segword_content.split() if it.isalpha() and it in vocab][:2000])
        segword_content = segword_content.replace(' 中 选 ', ' 中选 ').replace(' 中 标 ', ' 中标 ').replace(' 补 遗 ', ' 补遗 '). \
            replace(' 更 多', '').replace(' 更多', '').replace(' 中 号 ', ' 中标 ').replace(' 中 选人 ', ' 中选人 '). \
            replace(' 点击 下载 查看', '').replace(' 咨询 报价 请 点击', '').replace('终结', '终止')
        doc_word_list = segword_content.split()
        if len(doc_word_list) > sequen_len / 2:
            doc_sens = get_kw_senten(' '.join(doc_word_list[100:500]))
            doc_sens = ' '.join(doc_word_list[:100]) + '\n' + '\n'.join(doc_sens)
        else:
            doc_sens = ' '.join(doc_word_list[:sequen_len])
        datas.append(word2id(doc_sens.split(), max_len=sequen_len))
        datas_title.append(word2id(segword_title.split(), max_len=title_len))
        # print('preprocessing done')
        return datas, datas_title
    def predict_process(self, docid='', doctitle='', dochtmlcon=''):
        # print('preparing preprocessing')
        def get_kw_senten(s, span=10):
            doc_sens = []
            tmp = 0
            num = 0
            end_idx = 0
            for it in re.finditer(kws, s):  # '|'.join(keywordset)
                left = s[end_idx:it.end()].split()
                right = s[it.end():].split()
                tmp_seg = s[tmp:it.start()].split()
                if len(tmp_seg) > span or tmp == 0:
                    doc_sens.append(' '.join(left[-span:] + right[:span]))
                    end_idx = it.end() + 1 + len(' '.join(right[:span]))
                tmp = it.end()
                num += 1
                if num >= sentence_num:
                    break
            if doc_sens == []:
                doc_sens.append(s)
            return doc_sens

        def word2id(wordlist, max_len=sequen_len):
            ids = [word_index.get(w, 0) for w in wordlist]
            ids = ids[:max_len] if len(ids) >= max_len else ids + [0] * (max_len - len(ids))
            assert len(ids) == max_len
            return ids
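
        # A quick sketch of word2id's padding/truncation behaviour (with a
        # hypothetical vocabulary where '招标' maps to id 7 and '公告' to id 9):
        #   word2id(['招标', '公告'], max_len=4)  -> [7, 9, 0, 0]
        #   word2id(['招标'] * 6, max_len=4)      -> [7, 7, 7, 7]
        # Out-of-vocabulary words fall back to id 0, the same id as padding.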
        cost_time = dict()
        datas = []
        datas_title = []
        # articles = [[docid, dochtmlcon, '', '', doctitle]]
        try:
            # list_articles = Preprocessing.get_preprocessed_article(articles, cost_time)
            # list_sentences = Preprocessing.get_preprocessed_sentences(list_articles, True, cost_time)
            # sen_words = [sen.tokens for sen in list_sentences[0]]
            # words = [it for sen in sen_words for it in sen]
            # segword_content = ' '.join(words)
            segword_title = ' '.join(fool.cut(doctitle)[0])
            segword_content = dochtmlcon
            # segword_title = doctitle
        except:
            segword_content = ''
            segword_title = ''
        if isinstance(segword_content, float):
            segword_content = ''
        if isinstance(segword_title, float):
            segword_title = ''
        segword_content = segword_content.replace(' 中 选 ', ' 中选 ').replace(' 中 标 ', ' 中标 ').replace(' 补 遗 ', ' 补遗 '). \
            replace(' 更 多', '').replace(' 更多', '').replace(' 中 号 ', ' 中标 ').replace(' 中 选人 ', ' 中选人 '). \
            replace(' 点击 下载 查看', '').replace(' 咨询 报价 请 点击', '').replace('终结', '终止')
        segword_title = re.sub(r'[^\s\u4e00-\u9fa5]', '', segword_title)
        segword_content = re.sub(r'[^\s\u4e00-\u9fa5]', '', segword_content)
        doc_word_list = segword_content.split()
        if len(doc_word_list) > sequen_len / 2:
            doc_sens = get_kw_senten(' '.join(doc_word_list[100:500]))
            doc_sens = ' '.join(doc_word_list[:100]) + '\n' + '\n'.join(doc_sens)
        else:
            doc_sens = ' '.join(doc_word_list[:sequen_len])
        datas.append(word2id(doc_sens.split(), max_len=sequen_len))
        datas_title.append(word2id(segword_title.split(), max_len=title_len))
        # print('preprocessing done')
        return datas, datas_title
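
    # predict_process returns a pair of singleton batches: datas is
    # [[int] * sequen_len] (200 content token ids) and datas_title is
    # [[int] * title_len] (30 title token ids), zero-padded; the zero ids double
    # as the padding positions that the mask tensors are computed from.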
    def is_houxuan(self, title, content):
        '''
        Decide from the title and the Chinese body text whether an announcement
        is a winner-candidate publicity notice (候选人公示).
        :param title: announcement title
        :param content: announcement body text
        :return: 1 if it is a candidate publicity notice, 0 otherwise
        '''
        if re.search('候选人的?公示|评标结果|评审结果|中标公示', title):  # (中标|成交|中选|入围)
            if re.search('变更公告|更正公告|废标|终止|答疑|澄清', title):
                return 0
            return 1
        if re.search('候选人的?公示', content[:100]):
            if re.search('公示(期|活动)?已经?结束|公示期已满|中标结果公告|中标结果公示|变更公告|更正公告|废标|终止|答疑|澄清', content[:100]):
                return 0
            return 1
        else:
            return 0
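
    # For example (values checked against the regexes above):
    #   self.is_houxuan('中标候选人公示', '')          -> 1
    #   self.is_houxuan('中标候选人公示变更公告', '')  -> 0  (title hits the exclusion list)
    #   self.is_houxuan('废标公告', '')                -> 0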
    def predict(self, title='', content=''):
        # print('preparing prediction')
        if isinstance(content, list):
            token_l = [it.tokens for it in content]
            tokens = [it for l in token_l for it in l]
            content = ' '.join(tokens)
        data_content, data_title = self.predict_process(docid='', doctitle=title, dochtmlcon=content)
        pred = self.type_sess.run(self.type_softmax,
                                  feed_dict={self.type_title: [[embedding_matrix[i] for i in l] for l in data_title],
                                             self.type_content: [[embedding_matrix[i] for i in l] for l in data_content],
                                             self.type_mask: 1 - np.not_equal(data_content, 0),
                                             self.type_mask_title: 1 - np.not_equal(data_title, 0),
                                             self.type_prob: 1})
        cid = np.argmax(pred, axis=1)[0]
        prob = pred[0][cid]
        if cid == 0:  # 采招数据: refine with the life-cycle model
            pred = self.lift_sess.run(self.lift_softmax,
                                      feed_dict={self.lift_title: [[embedding_matrix[i] for i in l] for l in data_title],
                                                 self.lift_content: [[embedding_matrix[i] for i in l] for l in data_content],
                                                 self.mask: 1 - np.not_equal(data_content, 0),
                                                 self.mask_title: 1 - np.not_equal(data_title, 0),
                                                 self.lift_prob: 1})
            cid = np.argmax(pred, axis=1)[0]
            prob = pred[0][cid]
            if cid == 6:  # 中标信息: may actually be a candidate publicity notice
                if self.is_houxuan(''.join([it for it in title if it.isalpha()]), ''.join([it for it in content if it.isalpha()])):
                    # return '候选人公示', prob
                    return [{'docchannel': '候选人公示'}]
            # return self.id2life[cid], prob
            return [{'docchannel': self.id2life[cid]}]
        else:
            # return self.id2type[cid], prob
            return [{'docchannel': self.id2type[cid]}]
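
    # Hypothetical usage sketch (assumes the frozen models ship next to this
    # file under ./model, as the __init__ defaults expect):
    #   dc = DocChannel()
    #   result = dc.predict(title='某项目中标公告', content='... 中标 公告 正文 ...')
    #   # -> e.g. [{'docchannel': '中标信息'}]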
    def predict_batch(self, title_content_list):
        # print('preparing prediction')
        data_content = []
        data_title = []
        n = 0
        t0 = time.time()
        for docid, title, content in title_content_list:
            data_c, data_t = self.predict_process(docid=docid, doctitle=title, dochtmlcon=content)
            print('document preprocessed: %d' % docid)
            data_content.append(data_c[0])
            data_title.append(data_t[0])
            n += 1
            if n % 1024 == 0:
                print('%d documents preprocessed' % n)
        t1 = time.time()
        print('documents: %d, preprocessing time: %.4f' % (len(title_content_list), t1 - t0))
        bz = 2048
        tt_n = int((len(data_content) - 1) / bz + 1)
        types = []
        lifts = []
        for i in range(tt_n):
            pred = self.type_sess.run(self.type_softmax,
                                      feed_dict={self.type_title: [[embedding_matrix[w] for w in l] for l in data_title[i * bz:(i + 1) * bz]],
                                                 self.type_content: [[embedding_matrix[w] for w in l] for l in data_content[i * bz:(i + 1) * bz]],
                                                 self.type_mask: 1 - np.not_equal(data_content[i * bz:(i + 1) * bz], 0),
                                                 self.type_mask_title: 1 - np.not_equal(data_title[i * bz:(i + 1) * bz], 0),
                                                 self.type_prob: 1})
            # type_ids = np.argmax(pred, axis=1)
            types.extend(pred)
            lift_pred = self.lift_sess.run(self.lift_softmax,
                                           feed_dict={self.lift_title: [[embedding_matrix[w] for w in l] for l in data_title[i * bz:(i + 1) * bz]],
                                                      self.lift_content: [[embedding_matrix[w] for w in l] for l in data_content[i * bz:(i + 1) * bz]],
                                                      self.mask: 1 - np.not_equal(data_content[i * bz:(i + 1) * bz], 0),
                                                      self.mask_title: 1 - np.not_equal(data_title[i * bz:(i + 1) * bz], 0),
                                                      self.lift_prob: 1})
            # lift_ids = np.argmax(lift_pred, axis=1)
            lifts.extend(lift_pred)
            print('batch %d done' % i)
        preds = []
        probs = []
        for type_p, lift_p in zip(types, lifts):
            cid = np.argmax(type_p)
            if cid == 0:
                cid = np.argmax(lift_p)
                preds.append(self.id2life[cid])
                probs.append(lift_p[cid])
            else:
                preds.append(self.id2type[cid])
                probs.append(type_p[cid])
        t2 = time.time()
        print('prediction time: %.4f' % (t2 - t1))
        return preds, probs
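
# Design note on predict_batch: both graphs are run over every batch of 2048,
# and the life-cycle probabilities are only consulted afterwards for rows whose
# type argmax is 0 (采招数据); running the life model only on those rows would
# be a possible optimisation, at the cost of re-batching.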
# def channel_predict(df_path):
#     df_test = pd.read_excel(df_path)
#     df_test.reset_index(drop=True, inplace=True)
#     preds = []
#     probs = []
#     for i in range(len(df_test)):
#         # title = df_test.loc[i, 'doctitle']
#         # content = df_test.loc[i, 'dochtmlcon']
#         title = df_test.loc[i, 'segword_title']
#         content = df_test.loc[i, 'segword']
#         pred, prob = DocChannel.predict(title, content)
#         preds.append(pred)
#         probs.append(prob)
#         # print(pred, title)
#         # label = df_test.loc[i, 'label']
#         # if pred != label:
#         #     print('predicted: %s, prob: %.4f, labelled: %s, title: %s'
#         #           % (pred, prob, label, title))
#     df_test['pred_new'] = pd.Series(preds)
#     df_test['pred_prob'] = pd.Series(probs)
#     # df_test.to_excel(df_path[:-5]+'_predict.xlsx')
#     df_test.to_excel(df_path)
def is_houxuan(title, content):
    '''
    Decide from the title and the Chinese body text whether an announcement
    is a winner-candidate publicity notice (候选人公示).
    :param title: announcement title
    :param content: announcement body text
    :return: 1 if it is a candidate publicity notice, 0 otherwise
    '''
    if re.search('候选人的?公示|评标结果|评审结果|中标公示', title):  # (中标|成交|中选|入围)
        if re.search('变更公告|更正公告|废标|终止|答疑|澄清', title):
            return 0
        return 1
    if re.search('候选人的?公示', content[:100]):
        if re.search('公示(期|活动)?已经?结束|公示期已满|中标结果公告|中标结果公示|变更公告|更正公告|废标|终止|答疑|澄清', content[:100]):
            return 0
        return 1
    else:
        return 0
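
# This module-level is_houxuan duplicates DocChannel.is_houxuan so that
# channel_predict_batch below can apply the same rule without touching the
# model instance.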
def channel_predict_batch(df_path):
    print('batch prediction')
    df = pd.read_excel(df_path)
    df.fillna('', inplace=True)
    df.reset_index(drop=True, inplace=True)
    bz = 1024 * 10 * 6
    total_batch = int((len(df) - 1) / bz + 1)
    for i in range(total_batch):
        df_test = copy.deepcopy(df[i * bz:(i + 1) * bz])
        df_test.reset_index(drop=True, inplace=True)
        docs = [[docid, title, content] for docid, title, content in zip(df_test['docid'], df_test['segword_title'], df_test['segword'])]
        print('%d documents in total' % len(docs))
        preds, probs = DocChannel.predict_batch(docs)
        # df_test['pred_old'] = df_test['pred_new']
        df_test['pred_new'] = pd.Series(preds)
        df_test['pred_prob'] = pd.Series(probs)
        # df_test['old=new'] = df_test.apply(lambda x: 1 if x['pred_old'] == x['pred_new'] else 0, axis=1)
        # df_test = df_test[df_test.loc[:, 'old=new'] == 0]
        # print(df_test.head(3))
        # for idx in df_test.index:
        #     title = df_test.loc[idx, 'doctitle']
        #     text = re.sub('[^\u4e00-\u9fa5]', '', df_test.loc[idx, 'segword'])
        #     try:
        #         if is_houxuan(title, text) == 1:
        #             df_test.loc[idx, 'pred_new'] = '候选人公示'
        #     except:
        #         print('error', df_test.loc[idx, 'pred_new'], text)
        df_test['pred_new'] = df_test.apply(
            lambda x: '候选人公示' if x['pred_new'] == '中标信息' and is_houxuan(x['doctitle'], re.sub('[^\u4e00-\u9fa5]', '', x['segword'])) == 1 else x['pred_new'], axis=1)
        df_test.to_excel(df_path[:-5] + '_predict_new_{}.xlsx'.format(i))
        print('file saved')
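
# Each slice of up to 61,440 rows (bz = 1024 * 10 * 6) is written to its own
# <input name>_predict_new_<i>.xlsx file rather than back into the source
# spreadsheet; presumably this keeps earlier batches safe if a later one fails.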
if __name__ == "__main__":
    path = 'data/候选人公示.xlsx'
    # The instance deliberately shadows the class name, which lets the
    # module-level helpers above call DocChannel.predict_batch directly.
    DocChannel = DocChannel()
    # channel_predict_batch(path)
    for path in ['data/docchannel带数据源2021-04-12_bidi_process.xlsx',
                 'data/docchannel带数据源2021-04-13_bidi_process.xlsx',
                 'data/docchannel带数据源2021-04-14_bidi_process.xlsx',
                 'data/docchannel带数据源2021-04-15_bidi_process.xlsx',
                 'data/docchannel带数据源2021-04-16_bidi_process.xlsx']:
        channel_predict_batch(path)
    # for path in ['data/docchannel带数据源2021-04-12_bidi_process_predict_0.xlsx',
    #              'data/docchannel带数据源2021-04-13_bidi_process_predict_0.xlsx',
    #              # 'data/docchannel带数据源2021-04-14_bidi_process.xlsx',
    #              'data/docchannel带数据源2021-04-15_bidi_process_predict_0.xlsx',
    #              'data/docchannel带数据源2021-04-16_bidi_process_predict_0.xlsx']:
    #     channel_predict_batch(path)
    # df_test = pd.read_excel('data/df_test_公告类型.xlsx')