channel_predictor.py
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/6/10 0010 14:23
import BiddingKG.dl.interface.Preprocessing as Preprocessing  # used by the commented-out preprocessing path below
from BiddingKG.dl.common.Utils import getVocabAndMatrix, getModel_w2v, precision, recall, f1_score
import numpy as np
import pandas as pd
import copy
import tensorflow as tf
import fool
import re
import time

# Load the word2vec model and build the vocabulary / embedding lookup tables.
word_model = getModel_w2v()
vocab, embedding_matrix = getVocabAndMatrix(word_model, Embedding_size=128)
word_index = {k: v for v, k in enumerate(vocab)}  # word -> id
height, width = embedding_matrix.shape

sequen_len = 200   # content sequence length (150 / 200 both tried)
title_len = 30     # title sequence length
sentence_num = 10  # max number of keyword-centered snippets to extract

# Keywords used to locate the most informative snippets in an announcement.
kws = '供货商|候选人|供应商|入选人|项目|选定|预告|中标|成交|补遗|延期|报名|暂缓|结果|意向|出租|补充|合同|限价|比选|指定|工程|废标|取消|中止|流标|资质|资格|地块|招标|采购|货物|租赁|计划|宗地|需求|来源|土地|澄清|失败|探矿|预审|变更|变卖|遴选|撤销|意见|恢复|采矿|更正|终止|废置|报建|流拍|供地|登记|挂牌|答疑|中选|受让|拍卖|竞拍|审查|入围|更改|条件|洽谈|乙方|后审|控制|暂停|用地|询价|预'
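
# A minimal sketch of the id/padding scheme used throughout this file
# (illustrative only; the word2id helpers below are the actual implementation):
#
#   words = '招标 公告 某 项目'.split()                        # whitespace-segmented input
#   ids = [word_index.get(w, 0) for w in words]                # OOV words map to id 0
#   ids = ids[:sequen_len] + [0] * (sequen_len - len(ids))     # pad/truncate to fixed length
#   vectors = [embedding_matrix[i] for i in ids]               # what the feed_dicts build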
class DocChannel():
    def __init__(self, life_model='model/channel.pb', type_model='model/doctype.pb'):
        self.lift_sess, self.lift_title, self.lift_content, self.lift_prob, self.lift_softmax, \
            self.mask, self.mask_title = self.load_life(life_model)
        self.type_sess, self.type_title, self.type_content, self.type_prob, self.type_softmax, \
            self.type_mask, self.type_mask_title = self.load_type(type_model)
        # Coarse document types and fine-grained lifecycle (channel) labels.
        lb_type = ['采招数据', '土地矿产', '拍卖出让', '产权交易', '新闻资讯']
        lb_life = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
        self.id2type = {k: v for k, v in enumerate(lb_type)}
        self.id2life = {k: v for k, v in enumerate(lb_life)}
    def load_life(self, life_model):
        with tf.Graph().as_default() as graph:
            output_graph_def = graph.as_graph_def()
            with open(life_model, 'rb') as f:
                output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name='')
            print("%d ops in the final graph" % len(output_graph_def.node))
            del output_graph_def
            sess = tf.Session(graph=graph)
            sess.run(tf.global_variables_initializer())
            inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
            prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
            title = sess.graph.get_tensor_by_name('inputs/title:0')
            mask = sess.graph.get_tensor_by_name('inputs/mask:0')
            mask_title = sess.graph.get_tensor_by_name('inputs/mask_title:0')
            # logit = sess.graph.get_tensor_by_name('output/logit:0')
            softmax = sess.graph.get_tensor_by_name('output/softmax:0')
            return sess, title, inputs, prob, softmax, mask, mask_title

    def load_type(self, type_model):
        with tf.Graph().as_default() as graph:
            output_graph_def = graph.as_graph_def()
            with open(type_model, 'rb') as f:
                output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name='')
            print("%d ops in the final graph" % len(output_graph_def.node))
            del output_graph_def
            sess = tf.Session(graph=graph)
            sess.run(tf.global_variables_initializer())
            inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
            prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
            title = sess.graph.get_tensor_by_name('inputs/title:0')
            mask = sess.graph.get_tensor_by_name('inputs/mask:0')
            mask_title = sess.graph.get_tensor_by_name('inputs/mask_title:0')
            # logit = sess.graph.get_tensor_by_name('output/logit:0')
            softmax = sess.graph.get_tensor_by_name('output/softmax:0')
            return sess, title, inputs, prob, softmax, mask, mask_title
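
    # Both loaders follow the standard TF1 frozen-graph pattern: parse a
    # serialized GraphDef from the .pb file, import it into a fresh graph and
    # fetch the input/output tensors by name. A minimal standalone sketch
    # (assuming a hypothetical 'model/example.pb' exporting the same tensor names):
    #
    #   graph_def = tf.GraphDef()
    #   with open('model/example.pb', 'rb') as f:
    #       graph_def.ParseFromString(f.read())
    #   with tf.Graph().as_default() as g:
    #       tf.import_graph_def(graph_def, name='')
    #       sess = tf.Session(graph=g)
    #       softmax = g.get_tensor_by_name('output/softmax:0')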
    def predict_process_backup(self, docid='', doctitle='', dochtmlcon=''):
        # print('start preprocessing')
        def get_kw_senten(s, span=10):
            # Collect up to sentence_num snippets of +-span words around each keyword hit.
            doc_sens = []
            tmp = 0
            num = 0
            end_idx = 0
            for it in re.finditer(kws, s):  # '|'.join(keywordset)
                left = s[end_idx:it.end()].split()
                right = s[it.end():].split()
                tmp_seg = s[tmp:it.start()].split()
                if len(tmp_seg) > span or tmp == 0:
                    doc_sens.append(' '.join(left[-span:] + right[:span]))
                    end_idx = it.end() + 1 + len(' '.join(right[:span]))
                    tmp = it.end()
                    num += 1
                    if num >= sentence_num:
                        break
            if doc_sens == []:
                doc_sens.append(s)
            return doc_sens

        def word2id(wordlist, max_len=sequen_len):
            # Map words to vocabulary ids (OOV -> 0) and pad/truncate to max_len.
            ids = [word_index.get(w, 0) for w in wordlist]
            ids = ids[:max_len] if len(ids) >= max_len else ids + [0] * (max_len - len(ids))
            assert len(ids) == max_len
            return ids

        cost_time = dict()
        datas = []
        datas_title = []
        # articles = [[docid, dochtmlcon, '', '', doctitle]]
        try:
            # list_articles = Preprocessing.get_preprocessed_article(articles, cost_time)
            # list_sentences = Preprocessing.get_preprocessed_sentences(list_articles, True, cost_time)
            # sen_words = [sen.tokens for sen in list_sentences[0]]
            # words = [it for sen in sen_words for it in sen]
            # segword_content = ' '.join(words)
            # segword_title = ' '.join(fool.cut(doctitle)[0])
            segword_content = dochtmlcon
            segword_title = doctitle
        except:
            segword_content = ''
            segword_title = ''
        segword_title = ' '.join([it for it in segword_title.split() if it.isalpha() and it in vocab][:title_len])
        segword_content = ' '.join([it for it in segword_content.split() if it.isalpha() and it in vocab][:2000])
        segword_content = segword_content.replace(' 中 选 ', ' 中选 ').replace(' 中 标 ', ' 中标 ').replace(' 补 遗 ', ' 补遗 '). \
            replace(' 更 多', '').replace(' 更多', '').replace(' 中 号 ', ' 中标 ').replace(' 中 选人 ', ' 中选人 '). \
            replace(' 点击 下载 查看', '').replace(' 咨询 报价 请 点击', '').replace('终结', '终止')
        doc_word_list = segword_content.split()
        if len(doc_word_list) > sequen_len / 2:
            doc_sens = get_kw_senten(' '.join(doc_word_list[100:500]))
            doc_sens = ' '.join(doc_word_list[:100]) + '\n' + '\n'.join(doc_sens)
        else:
            doc_sens = ' '.join(doc_word_list[:sequen_len])
        datas.append(word2id(doc_sens.split(), max_len=sequen_len))
        datas_title.append(word2id(segword_title.split(), max_len=title_len))
        # print('preprocessing done')
        return datas, datas_title
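
    # predict_process_backup differs from predict_process below: it keeps the
    # already-segmented title and filters both fields against the vocabulary,
    # while predict_process re-segments the title with fool and strips
    # non-Chinese characters instead.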
    def predict_process(self, docid='', doctitle='', dochtmlcon=''):
        # print('start preprocessing')
        def get_kw_senten(s, span=10):
            # Collect up to sentence_num snippets of +-span words around each keyword hit.
            doc_sens = []
            tmp = 0
            num = 0
            end_idx = 0
            for it in re.finditer(kws, s):  # '|'.join(keywordset)
                left = s[end_idx:it.end()].split()
                right = s[it.end():].split()
                tmp_seg = s[tmp:it.start()].split()
                if len(tmp_seg) > span or tmp == 0:
                    doc_sens.append(' '.join(left[-span:] + right[:span]))
                    end_idx = it.end() + 1 + len(' '.join(right[:span]))
                    tmp = it.end()
                    num += 1
                    if num >= sentence_num:
                        break
            if doc_sens == []:
                doc_sens.append(s)
            return doc_sens

        def word2id(wordlist, max_len=sequen_len):
            # Map words to vocabulary ids (OOV -> 0) and pad/truncate to max_len.
            ids = [word_index.get(w, 0) for w in wordlist]
            ids = ids[:max_len] if len(ids) >= max_len else ids + [0] * (max_len - len(ids))
            assert len(ids) == max_len
            return ids

        cost_time = dict()
        datas = []
        datas_title = []
        # articles = [[docid, dochtmlcon, '', '', doctitle]]
        try:
            # list_articles = Preprocessing.get_preprocessed_article(articles, cost_time)
            # list_sentences = Preprocessing.get_preprocessed_sentences(list_articles, True, cost_time)
            # sen_words = [sen.tokens for sen in list_sentences[0]]
            # words = [it for sen in sen_words for it in sen]
            # segword_content = ' '.join(words)
            segword_title = ' '.join(fool.cut(doctitle)[0])
            segword_content = dochtmlcon
            # segword_title = doctitle
        except:
            segword_content = ''
            segword_title = ''
        if isinstance(segword_content, float):  # NaN cells read from Excel
            segword_content = ''
        if isinstance(segword_title, float):
            segword_title = ''
        segword_content = segword_content.replace(' 中 选 ', ' 中选 ').replace(' 中 标 ', ' 中标 ').replace(' 补 遗 ', ' 补遗 '). \
            replace(' 更 多', '').replace(' 更多', '').replace(' 中 号 ', ' 中标 ').replace(' 中 选人 ', ' 中选人 '). \
            replace(' 点击 下载 查看', '').replace(' 咨询 报价 请 点击', '').replace('终结', '终止')
        # Keep only Chinese characters and whitespace.
        segword_title = re.sub(r'[^\s\u4e00-\u9fa5]', '', segword_title)
        segword_content = re.sub(r'[^\s\u4e00-\u9fa5]', '', segword_content)
        doc_word_list = segword_content.split()
        if len(doc_word_list) > sequen_len / 2:
            doc_sens = get_kw_senten(' '.join(doc_word_list[100:500]))
            doc_sens = ' '.join(doc_word_list[:100]) + '\n' + '\n'.join(doc_sens)
        else:
            doc_sens = ' '.join(doc_word_list[:sequen_len])
        datas.append(word2id(doc_sens.split(), max_len=sequen_len))
        datas_title.append(word2id(segword_title.split(), max_len=title_len))
        # print('preprocessing done')
        return datas, datas_title
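
    # Note: for long documents the content fed to the model is the first 100
    # words plus up to sentence_num keyword-centered snippets drawn from words
    # 100-500, then padded/truncated to sequen_len.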
    def is_houxuan(self, title, content):
        '''
        Decide from the title and Chinese body text whether the announcement
        belongs to the winning-candidate publicity (候选人公示) category.
        :param title: announcement title
        :param content: announcement body text
        :return: 1 if it is a candidate publicity; 0 otherwise
        '''
        if re.search('候选人的?公示|评标结果|评审结果|中标公示', title):  # (中标|成交|中选|入围)
            if re.search('变更公告|更正公告|废标|终止|答疑|澄清', title):
                return 0
            return 1
        if re.search('候选人的?公示', content[:100]):
            if re.search('公示(期|活动)?已经?结束|公示期已满|中标结果公告|中标结果公示|变更公告|更正公告|废标|终止|答疑|澄清', content[:100]):
                return 0
            return 1
        else:
            return 0
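
    # Illustrative behavior (hypothetical titles):
    #   self.is_houxuan('某项目中标候选人公示', '')         -> 1
    #   self.is_houxuan('某项目中标候选人公示变更公告', '')  -> 0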
    def predict(self, title, content):
        # print('start predicting')
        data_content, data_title = self.predict_process(docid='', doctitle=title, dochtmlcon=content)
        pred = self.type_sess.run(self.type_softmax,
                                  feed_dict={self.type_title: [[embedding_matrix[i] for i in l] for l in data_title],
                                             self.type_content: [[embedding_matrix[i] for i in l] for l in data_content],
                                             self.type_mask: 1 - np.not_equal(data_content, 0),
                                             self.type_mask_title: 1 - np.not_equal(data_title, 0),
                                             self.type_prob: 1})
        id = np.argmax(pred, axis=1)[0]
        prob = pred[0][id]
        if id == 0:  # '采招数据': route to the lifecycle model for the fine-grained label
            pred = self.lift_sess.run(self.lift_softmax,
                                      feed_dict={self.lift_title: [[embedding_matrix[i] for i in l] for l in data_title],
                                                 self.lift_content: [[embedding_matrix[i] for i in l] for l in data_content],
                                                 self.mask: 1 - np.not_equal(data_content, 0),
                                                 self.mask_title: 1 - np.not_equal(data_title, 0),
                                                 self.lift_prob: 1})
            id = np.argmax(pred, axis=1)[0]
            prob = pred[0][id]
            if id == 6:  # '中标信息': check whether it is really a candidate publicity
                if self.is_houxuan(''.join([it for it in title if it.isalpha()]),
                                   ''.join([it for it in content if it.isalpha()])):
                    return '候选人公示', prob
            return self.id2life[id], prob
        else:
            return self.id2type[id], prob
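
    # Usage sketch (hypothetical inputs, assuming the model/*.pb files are in
    # place; the body is whitespace-segmented as produced upstream):
    #   channel = DocChannel()
    #   label, prob = channel.predict('某 项目 招标 公告', '某 单位 就 某 项目 进行 公开 招标 ...')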
    def predict_batch(self, title_content_list):
        # print('start predicting')
        data_content = []
        data_title = []
        n = 0
        t0 = time.time()
        for docid, title, content in title_content_list:
            data_c, data_t = self.predict_process(docid=docid, doctitle=title, dochtmlcon=content)
            print('finished preprocessing article: %d' % docid)
            data_content.append(data_c[0])
            data_title.append(data_t[0])
            n += 1
            if n % 1024 == 0:
                print('preprocessed %d articles' % n)
        t1 = time.time()
        print('articles: %d, preprocessing took: %.4fs' % (len(title_content_list), t1 - t0))
        bz = 2048
        tt_n = int((len(data_content) - 1) / bz + 1)
        types = []
        lifts = []
        for i in range(tt_n):
            pred = self.type_sess.run(self.type_softmax,
                                      feed_dict={self.type_title: [[embedding_matrix[w] for w in l] for l in data_title[i*bz:(i+1)*bz]],
                                                 self.type_content: [[embedding_matrix[w] for w in l] for l in data_content[i*bz:(i+1)*bz]],
                                                 self.type_mask: 1 - np.not_equal(data_content[i*bz:(i+1)*bz], 0),
                                                 self.type_mask_title: 1 - np.not_equal(data_title[i*bz:(i+1)*bz], 0),
                                                 self.type_prob: 1})
            # type_ids = np.argmax(pred, axis=1)
            types.extend(pred)
            lift_pred = self.lift_sess.run(self.lift_softmax,
                                           feed_dict={self.lift_title: [[embedding_matrix[w] for w in l] for l in data_title[i*bz:(i+1)*bz]],
                                                      self.lift_content: [[embedding_matrix[w] for w in l] for l in data_content[i*bz:(i+1)*bz]],
                                                      self.mask: 1 - np.not_equal(data_content[i*bz:(i+1)*bz], 0),
                                                      self.mask_title: 1 - np.not_equal(data_title[i*bz:(i+1)*bz], 0),
                                                      self.lift_prob: 1})
            # lift_ids = np.argmax(lift_pred, axis=1)
            lifts.extend(lift_pred)
            print('finished batch %d' % i)
        preds = []
        probs = []
        for type_pred, life_pred in zip(types, lifts):
            id = np.argmax(type_pred)
            if id == 0:
                id = np.argmax(life_pred)
                preds.append(self.id2life[id])
                probs.append(life_pred[id])
            else:
                preds.append(self.id2type[id])
                probs.append(type_pred[id])
        t2 = time.time()
        print('prediction took %.4fs' % (t2 - t1))
        return preds, probs
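
    # Design note: classification is two-stage. The type model picks one of
    # five coarse categories; only documents classified as '采招数据' (index 0)
    # take the lifecycle model's fine-grained label. bz = 2048 caps how many
    # documents have their embeddings materialized per session run.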
# def channel_predict(df_path):
#     df_test = pd.read_excel(df_path)
#     df_test.reset_index(drop=True, inplace=True)
#     preds = []
#     probs = []
#     for i in range(len(df_test)):
#         # title = df_test.loc[i, 'doctitle']
#         # content = df_test.loc[i, 'dochtmlcon']
#         title = df_test.loc[i, 'segword_title']
#         content = df_test.loc[i, 'segword']
#         pred, prob = doc_channel.predict(title, content)
#         preds.append(pred)
#         probs.append(prob)
#         # print(pred, title)
#         # label = df_test.loc[i, 'label']
#         # if pred != label:
#         #     print('predicted: %s, prob: %.4f, labeled: %s, title: %s'
#         #           % (pred, prob, label, title))
#     df_test['pred_new'] = pd.Series(preds)
#     df_test['pred_prob'] = pd.Series(probs)
#     # df_test.to_excel(df_path[:-5]+'_predict.xlsx')
#     df_test.to_excel(df_path)
def is_houxuan(title, content):
    '''
    Decide from the title and Chinese body text whether the announcement
    belongs to the winning-candidate publicity (候选人公示) category.
    :param title: announcement title
    :param content: announcement body text
    :return: 1 if it is a candidate publicity; 0 otherwise
    '''
    if re.search('候选人的?公示|评标结果|评审结果|中标公示', title):  # (中标|成交|中选|入围)
        if re.search('变更公告|更正公告|废标|终止|答疑|澄清', title):
            return 0
        return 1
    if re.search('候选人的?公示', content[:100]):
        if re.search('公示(期|活动)?已经?结束|公示期已满|中标结果公告|中标结果公示|变更公告|更正公告|废标|终止|答疑|澄清', content[:100]):
            return 0
        return 1
    else:
        return 0
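
# Note: this module-level is_houxuan duplicates DocChannel.is_houxuan so that
# channel_predict_batch below can apply it row-wise without an instance.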
def channel_predict_batch(df_path):
    print('batch prediction')
    df = pd.read_excel(df_path)
    df.fillna('', inplace=True)
    df.reset_index(drop=True, inplace=True)
    bz = 1024 * 10 * 6
    total_batch = int((len(df) - 1) / bz + 1)
    for i in range(total_batch):
        df_test = copy.deepcopy(df[i*bz:(i+1)*bz])
        df_test.reset_index(drop=True, inplace=True)
        docs = [[docid, title, content] for docid, title, content in
                zip(df_test['docid'], df_test['segword_title'], df_test['segword'])]
        print('%d articles in total' % len(docs))
        preds, probs = doc_channel.predict_batch(docs)
        # df_test['pred_old'] = df_test['pred_new']
        df_test['pred_new'] = pd.Series(preds)
        df_test['pred_prob'] = pd.Series(probs)
        # df_test['old=new'] = df_test.apply(lambda x: 1 if x['pred_old'] == x['pred_new'] else 0, axis=1)
        # df_test = df_test[df_test.loc[:, 'old=new'] == 0]
        # print(df_test.head(3))
        # for idx in df_test.index:
        #     title = df_test.loc[idx, 'doctitle']
        #     text = re.sub('[^\u4e00-\u9fa5]', '', df_test.loc[idx, 'segword'])
        #     try:
        #         if is_houxuan(title, text) == 1:
        #             df_test.loc[idx, 'pred_new'] = '候选人公示'
        #     except:
        #         print('error', df_test.loc[idx, 'pred_new'], text)
        # Promote '中标信息' rows to '候选人公示' when title/body indicate a candidate publicity.
        df_test['pred_new'] = df_test.apply(
            lambda x: '候选人公示' if x['pred_new'] == '中标信息'
            and is_houxuan(x['doctitle'], re.sub(r'[^\u4e00-\u9fa5]', '', x['segword'])) == 1
            else x['pred_new'], axis=1)
        df_test.to_excel(df_path[:-5] + '_predict_new_{}.xlsx'.format(i))
        print('saved file successfully')
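
# Input expectations (inferred from the column accesses above): the Excel file
# must contain 'docid', 'doctitle', 'segword_title' (segmented title) and
# 'segword' (segmented body) columns; predictions are written back as
# 'pred_new' / 'pred_prob' in per-batch output files.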
if __name__ == "__main__":
    path = 'data/候选人公示.xlsx'
    doc_channel = DocChannel()
    # channel_predict_batch(path)
    for path in ['data/docchannel带数据源2021-04-12_bidi_process.xlsx',
                 'data/docchannel带数据源2021-04-13_bidi_process.xlsx',
                 'data/docchannel带数据源2021-04-14_bidi_process.xlsx',
                 'data/docchannel带数据源2021-04-15_bidi_process.xlsx',
                 'data/docchannel带数据源2021-04-16_bidi_process.xlsx']:
    # for path in ['data/docchannel带数据源2021-04-12_bidi_process_predict_0.xlsx',
    #              'data/docchannel带数据源2021-04-13_bidi_process_predict_0.xlsx',
    #              # 'data/docchannel带数据源2021-04-14_bidi_process.xlsx',
    #              'data/docchannel带数据源2021-04-15_bidi_process_predict_0.xlsx',
    #              'data/docchannel带数据源2021-04-16_bidi_process_predict_0.xlsx']:
        channel_predict_batch(path)
    # df_test = pd.read_excel('data/df_test_公告类型.xlsx')