#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/5/11 0011 19:31
import pandas as pd
import numpy as np
import tensorflow as tf
import re
import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import glob
import copy
import pickle
import BiddingKG.dl.interface.Preprocessing as Preprocessing
from BiddingKG.dl.common.Utils import getVocabAndMatrix, getModel_w2v, precision, recall, f1_score

label2key = {
    '中标信息': 101,
    '业主采购': 113,
    '产权交易': 117,
    '企业名录': 110,
    '企业资质': 111,
    '全国工程': 112,
    '公告变更': 51,
    '土地矿产': 116,
    '展会推广': 109,
    '拍卖出让': 115,
    '招标公告': 52,
    '招标文件': 104,
    '招标答疑': 103,
    '招标预告': 102,
    '拟建项目': 108,
    '新闻资讯': 107,
    '法律法规': 106,
    '资审结果': 105,
    '采购意向': 114}
key2label = {v: k for k, v in label2key.items()}

word_model = getModel_w2v()
vocab, embedding_matrix = getVocabAndMatrix(word_model, Embedding_size=128)
word_index = {k: v for v, k in enumerate(vocab)}
height, width = embedding_matrix.shape
print('词向量.shape', embedding_matrix.shape)
print('词典大小', len(vocab))

sequen_len = 200  # 150 200
title_len = 30
sentence_num = 10

keywords = []
for file in glob.glob('data/类别关键词/*.txt'):
    with open(file, 'r', encoding='utf-8') as f:
        text = f.read()
        tmp_kw = [it for it in text.split('\n') if it]
        keywords.extend(tmp_kw)
keywordset = sorted(set(keywords), key=lambda x: len(x), reverse=True)

# kws = '资格|资质|预审|后审|审查|入围|意向|预告|预|需求|计划|意见|登记|报建|变更|更正|暂停|暂缓|延期|恢复|撤销|\
# 取消|更改|答疑|补遗|补充|澄清|限价|控制|终止|中止|废标|失败|废置|流标|合同|乙方|受让|中标|中选|成交|指定|选定\
# |结果|候选人|来源|供应商|供货商|入选人|条件|报名'
# kws2 = '拍卖|竞拍|流拍|变卖|土地|用地|地块|宗地|供地|采矿|探矿|出租|租赁|挂牌|招标|遴选|比选|询价|洽谈|采购|工程|项目|货物|供应商|候选人|中标|中选|成交'
# kws = '供货商|候选人|供应商|入选人|项目|选定|预告|中标|成交|补遗|延期|报名|暂缓|结果|意向|出租|补充|合同|限价|比选|指定|工程|废标|取消|中止|流标|资质|资格|地块|招标|采购|货物|租赁|计划|宗地|需求|来源|土地|澄清|失败|探矿|预审|变更|变卖|遴选|撤销|意见|恢复|采矿|更正|终止|废置|报建|流拍|供地|登记|挂牌|答疑|中选|受让|拍卖|竞拍|审查|入围|更改|条件|洽谈|乙方|后审|控制|暂停|用地|询价|预'
kws = '供货商|候选人|供应商|入选人|选定|中标|成交|合同|指定|废标|中止|流标|地块|宗地|土地|澄清|失败|预审|变更|变卖|更正|终止|废置|流拍|供地|挂牌|答疑|中选|受让|拍卖|竞拍|审查|入围|洽谈|乙方|后审|用地'
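# `kws` is used directly as a regex pattern by get_kw_senten() below; it is a trimmed-down
# '|'-joined alternation of trigger words, while the longer commented-out variants above appear
# to be earlier, broader keyword sets.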
def get_kw_senten_backup(s, span=10):
    doc_sens = []
    tmp = 0
    num = 0
    for it in re.finditer('|'.join(keywordset), s):
        left = s[:it.end()].split()
        right = s[it.end():].split()
        tmp_seg = s[tmp:it.start()].split()
        if len(tmp_seg) > span or tmp == 0:
            if len(left) >= span:
                doc_sens.append(' '.join(left[-span:] + right[:span]))
            else:
                doc_sens.append(' '.join(left + right[:(span + span - len(left))]))
            tmp = it.end()
            num += 1
            if num >= sentence_num:
                break
    if doc_sens == []:
        doc_sens.append(s)
    return doc_sens


def get_kw_senten(s, span=10):
    doc_sens = []
    tmp = 0
    num = 0
    end_idx = 0
    for it in re.finditer(kws, s):  # '|'.join(keywordset)
        left = s[end_idx:it.end()].split()
        right = s[it.end():].split()
        tmp_seg = s[tmp:it.start()].split()
        if len(tmp_seg) > span or tmp == 0:
            doc_sens.append(' '.join(left[-span:] + right[:span]))
            print(it.group(0), doc_sens[-1])
            end_idx = it.end() + 1 + len(' '.join(right[:span]))
            tmp = it.end()
            num += 1
            if num >= sentence_num:
                break
    if doc_sens == []:
        doc_sens.append(s)
    return doc_sens
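# Hedged, illustrative sketch of what get_kw_senten() does (not part of the original pipeline):
# for a whitespace-tokenised string it returns up to `sentence_num` windows of roughly `span`
# tokens on either side of each keyword hit from `kws`, e.g.
#     get_kw_senten('... 经 评审 确定 中标 供应商 为 某 公司 ...', span=10)
# yields snippets centred on matches such as '中标' / '供应商'; if nothing in `kws` matches,
# the whole input string is returned as the only element.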
def word2id(wordlist, max_len=sequen_len):
    # words = [word for word in wordlist if word.isalpha()]
    ids = [word_index.get(w, 0) for w in wordlist]
    # if re.search('[\u4e00-\u9fa5]', w) and w in word_index]
    ids = ids[:max_len] if len(ids) >= max_len else ids + [0] * (max_len - len(ids))
    assert len(ids) == max_len
    return ids
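# Hedged usage sketch: word2id() maps tokens to row indices of embedding_matrix (0 for
# out-of-vocabulary tokens) and pads or truncates to a fixed length, e.g.
#     word2id(['招标', '公告'], max_len=5)  ->  [idx_招标, idx_公告, 0, 0, 0]
# where idx_招标 / idx_公告 stand for whatever indices those tokens happen to have in word_index.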
def cut_words(filename):
    # df = pd.read_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源0413_filter.xlsx')
    # df = pd.read_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源_predict3.xlsx')
    df = pd.read_excel('data/{}.xlsx'.format(filename))
    df.fillna('', inplace=True)
    df.reset_index(drop=True, inplace=True)
    segword_list = []
    segword_title = []
    bz = 1024
    # articles = [[doc_id, html,"",doc_id, title] for doc_id, html, title in zip(df['docid'],df['dochtmlcon'],df['doctitle'])]
    # articles_title = [[doc_id, title,"",doc_id, title] for doc_id, html, title in zip(df['docid'],df['dochtmlcon'],df['doctitle'])]
    for i in df.index:
        articles = [[df.loc[i, 'docid'], df.loc[i, 'dochtmlcon'], "", df.loc[i, 'docid'], df.loc[i, 'doctitle']]]
        articles_title = [[df.loc[i, 'docid'], df.loc[i, 'doctitle'], "", df.loc[i, 'docid'], df.loc[i, 'doctitle']]]
        # list_articles, list_sentences, list_entitys, _ = Preprocessing.get_preprocessed(articles[i*bz:(i+1)*bz], useselffool=True)
        cost_time = dict()
        try:
            list_articles = Preprocessing.get_preprocessed_article(articles, cost_time)
            list_sentences = Preprocessing.get_preprocessed_sentences(list_articles, True, cost_time)
            for doc in list_sentences:
                sen_words = [sen.tokens for sen in doc]
                words = [it for sen in sen_words for it in sen]
                segword_list.append(' '.join(words))
        except:
            print('正文处理出错', df.loc[i, 'docid'])
            segword_list.append('')
        # list_articles_title, list_sentences_title, list_entitys_title, _ = Preprocessing.get_preprocessed(articles_title[i*bz:(i+1)*bz], useselffool=True)
        cost_time = dict()
        try:
            list_articles_title = Preprocessing.get_preprocessed_article(articles_title, cost_time)
            list_sentences_title = Preprocessing.get_preprocessed_sentences(list_articles_title, True, cost_time)
            for doc in list_sentences_title:
                sen_words = [sen.tokens for sen in doc]
                words = [it for sen in sen_words for it in sen]
                segword_title.append(' '.join(words))
        except:
            print('标题处理出错', df.loc[i, 'docid'])
            segword_title.append('')
        print(i)
    df['segword'] = segword_list
    df['segword_title'] = segword_title
    print(df.head(3))
    # df.to_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源0413_filter_bidi_process.xlsx')
    # df.to_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源_bidi_process.xlsx')
    df.to_excel('data/{}_bidi_process.xlsx'.format(filename))
    print('')
def split_train_test(df, split_rate=0.1):
    import copy
    train = []
    test = []
    df_train = pd.DataFrame()
    df_test = pd.DataFrame()
    for lb in set(df['label']):
        df_tmp = copy.deepcopy(df[df.loc[:, 'label'] == lb])
        df_tmp = df_tmp.sample(frac=1)
        train.append(df_tmp[int(split_rate * len(df_tmp)):])
        test.append(df_tmp[:int(split_rate * len(df_tmp))])
    df_train = df_train.append(train, ignore_index=True)
    df_test = df_test.append(test, ignore_index=True)
    return df_train.sample(frac=1), df_test.sample(frac=1)
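# split_train_test() is effectively a stratified split: each label is shuffled and split
# separately, so the test set keeps roughly `split_rate` of every class. Typical use in this file:
#     df_train, df_test = split_train_test(df, split_rate=0.1)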
def data_process(df, label2id):
    df.fillna('', inplace=True)
    datas_title = []
    datas = []
    labels = []
    doc_content = []
    doc_title = []
    for segword, segword2, label in zip(df['segword_title'], df['segword'], df['label']):
        segword = segword.replace(' 中 选 ', ' 中选 ').replace(' 中 标 ', ' 中标 ').replace(' 补 遗 ', ' 补遗 ').replace(' 更 多 ', ' 更多 ').replace(' 中 号 ', ' 中标 ').replace(' 中 选人 ', ' 中选人 ')
        segword = [w for w in segword.split() if w.isalpha() and re.search('[a-zA-Z]', w) is None and w in word_index]
        datas_title.append(word2id(segword[-title_len:], max_len=title_len))
        segword2 = segword2.replace(' 中 选 ', ' 中选 ').replace(' 中 标 ', ' 中标 ').replace(' 补 遗 ', ' 补遗 ').replace(' 更 多 ', ' 更多 ').replace(' 中 号 ', ' 中标 ').replace(' 中 选人 ', ' 中选人 ')
        segword2 = [w for w in segword2.split() if w.isalpha() and re.search('[a-zA-Z]', w) is None and w in word_index]
        datas.append(word2id(segword2, max_len=sequen_len))
        # labels.append(label2id[label])
        if label in label2id:
            labels.append(label2id[label])
        else:
            print('测试状态:%s 不在标签列' % label)
            labels.append(label2id.get(label, 0))
        doc_content.append(' '.join(segword2[:sequen_len]))
        doc_title.append(' '.join(segword[-title_len:]))
    onehot = np.zeros((len(labels), len(label2id)))
    df['content_input'] = pd.Series(doc_content)
    df['title_input'] = pd.Series(doc_title)
    for i in range(len(onehot)):
        onehot[i][labels[i]] = 1
    return np.array(datas), onehot, np.array(datas_title), df
def data_process_sentence(df, label2id):
    df.fillna('', inplace=True)
    df.reset_index(drop=True, inplace=True)
    datas_title = []
    datas = []
    labels = []
    sentence_input = []
    for segword, segword2, label in zip(df['segword_title'], df['segword'], df['label']):
        # segword = ' '.join([it for it in segword.split() if it.isalpha()][:title_len])
        # segword2 = ' '.join([it for it in segword2.split() if it.isalpha()][:2000])
        segword = re.sub('[^\s\u4e00-\u9fa5]', '', segword)
        segword2 = re.sub('[^\s\u4e00-\u9fa5]', '', segword2)
        segword2 = segword2.replace(' 中 选 ', ' 中选 ').replace(' 中 标 ', ' 中标 ').replace(' 补 遗 ', ' 补遗 '). \
            replace(' 更 多', '').replace(' 更多', '').replace(' 中 号 ', ' 中标 ').replace(' 中 选人 ', ' 中选人 '). \
            replace(' 点击 下载 查看', '').replace(' 咨询 报价 请 点击', '').replace('终结', '终止').replace('废除', '废标')
        doc_word_list = segword2.split()
        # doc_sens = ' '.join(doc_word_list[:sequen_len])
        if len(doc_word_list) > sequen_len / 2:
            doc_sens = get_kw_senten(' '.join(doc_word_list[100:500]))
            # doc_sens = ' '.join(doc_word_list[:100]+doc_sens)
            doc_sens = ' '.join(doc_word_list[:100]) + '\n' + '\n'.join(doc_sens)
        else:
            doc_sens = ' '.join(doc_word_list[:sequen_len])
        sentence_input.append(doc_sens)
        # sentence_input.append(' '.join(doc_sens))
        # if len(doc_sens)<1:
        #     continue
        # assert len(doc_ids) == sentence_num
        # assert len(doc_ids[-1]) == sequen_len
        # datas.append(word2id(' '.join(doc_sens).split(), max_len=sequen_len))
        datas.append(word2id(doc_sens.split(), max_len=sequen_len))
        datas_title.append(word2id(segword.split(), max_len=title_len))
        # labels.append(label2id[label])
        if label in label2id:
            labels.append(label2id[label])
        else:
            print('测试状态:%s 不在标签列' % label)
            labels.append(label2id.get(label, 0))
    df['content_input'] = pd.Series(sentence_input)
    # onehot = np.zeros((len(labels), len(label2id)))
    # for i in range(len(onehot)):
    #     onehot[i][labels[i]] = 1
    # return np.array(datas), onehot, np.array(datas_title), df
    return datas, labels, datas_title, df
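# Unlike data_process(), data_process_sentence() returns plain Python lists (content token-id
# rows, integer labels, title token-id rows) plus the DataFrame with a 'content_input' column
# holding the text actually fed to the model; one-hot encoding is left to tf.one_hot inside the
# graph.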
def data_process_backup(df, label2id):
    # aticles = [(id, text) for id, text in zip(df['docid'], df['dochtml'])]
    # datas, _ = clean_word_with_tokenizer(aticles, remove_word,tokenizer)
    # datas = [word2id(segword.split()) for segword in df['segword']]
    datas_title = []
    for segword in df['segword_title']:
        if isinstance(segword, str):
            segword = segword.replace(' 中 选 ', ' 中选 ').replace(' 中选人 ', ' 中选 人 ')
            datas_title.append(word2id(segword.split()[-title_len:], max_len=title_len))
        else:
            datas_title.append(word2id([], max_len=title_len))
    datas = []
    for segword, segword2 in zip(df['segword_title'], df['segword']):
        # if isinstance(segword, str) and segword not in segword2:
        #     segword = segword.replace(' 中 选 ', ' 中选 ').replace(' 中选人 ', ' 中选 人 ')
        #     segword2 = segword2.replace(' 中 选 ', ' 中选 ').replace(' 中选人 ', ' 中选 人 ')
        #     datas.append(word2id((segword+' '+segword2).split()))
        # else:
        segword2 = segword2.replace(' 中 选 ', ' 中选 ').replace(' 中选人 ', ' 中选 人 ')
        datas.append(word2id(segword2.split()))
    labels = list(df['label'].apply(lambda x: label2id[x]))
    onehot = np.zeros((len(labels), len(label2id)))
    for i in range(len(onehot)):
        onehot[i][labels[i]] = 1
    return np.array(datas), onehot, np.array(datas_title)
def attention(inputs, mask):
    with tf.variable_scope('attention', reuse=tf.AUTO_REUSE):
        hidden_size = inputs.shape[2].value
        u = tf.get_variable(name='u', shape=[hidden_size], dtype=tf.float32,
                            initializer=tf.keras.initializers.glorot_normal())
        with tf.name_scope('v'):
            v = tf.tanh(inputs)
        vu = tf.tensordot(v, u, axes=1, name='vu')
        vu += tf.cast(mask, dtype=tf.float32) * (-10000)
        alphas = tf.nn.softmax(vu, name='alphas')
        output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), 1)
        output = tf.tanh(output, name='att_out')
        return output, alphas
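# The mask passed to attention() is boolean (True at padding positions, from tf.equal(ids, 0));
# casting it to float and adding -10000 pushes padded scores to a large negative value, so the
# softmax gives padding ~zero weight. E.g. scores [1.0, 2.0] plus one padded position become
# roughly [1.0, 2.0, -10000] and softmax to about [0.27, 0.73, 0.0].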
def attention_new(inputs, mask):
    w = tf.get_variable('w', shape=(inputs.shape[2].value, 1),
                        dtype=tf.float32, initializer=tf.random_normal_initializer())
    b = tf.get_variable('b', shape=(inputs.shape[1].value, 1),
                        dtype=tf.float32, initializer=tf.zeros_initializer())
    u = tf.get_variable('u', shape=(inputs.shape[1].value, inputs.shape[1].value),
                        dtype=tf.float32, initializer=tf.random_normal_initializer())
    et = tf.squeeze(tf.tanh(tf.tensordot(inputs, w, axes=1) + b), axis=-1)
    at = tf.matmul(et, u)
    at = tf.add(at, tf.cast(mask, dtype=tf.float32) * (-10000))
    at = tf.exp(at)
    at_sum = tf.cast(tf.reduce_sum(at, axis=1, keepdims=True) + 1e-10, tf.float32)
    at = tf.divide(at, at_sum, name='alphas')
    alpha = tf.expand_dims(at, axis=-1)
    ot = alpha * inputs
    return tf.reduce_sum(ot, axis=1), at
def attention_han(inputs,
                  initializer=tf.contrib.layers.xavier_initializer(),
                  activation_fn=tf.tanh, scope=None):
    """
    Performs task-specific attention reduction, using learned
    attention context vector (constant within task of interest).
    Args:
        inputs: Tensor of shape [batch_size, units, input_size]
            `input_size` must be static (known)
            `units` axis will be attended over (reduced from output)
            `batch_size` will be preserved
        output_size: Size of output's inner (feature) dimension
    Returns:
        outputs: Tensor of shape [batch_size, output_dim].
    """
    assert len(inputs.get_shape()) == 3 and inputs.get_shape()[-1].value is not None
    output_size = inputs.shape[-1].value
    with tf.variable_scope(scope or 'attention') as scope:
        attention_context_vector = tf.get_variable(name='attention_context_vector',
                                                    shape=[output_size],
                                                    initializer=initializer,
                                                    dtype=tf.float32)
        input_projection = tf.contrib.layers.fully_connected(inputs, output_size,
                                                             activation_fn=activation_fn,
                                                             scope=scope)
        vector_attn = tf.reduce_sum(tf.multiply(input_projection, attention_context_vector), axis=2, keepdims=True)
        attention_weights = tf.nn.softmax(vector_attn, axis=1)
        alpha = tf.squeeze(attention_weights, axis=-1, name='alphas')
        weighted_projection = tf.multiply(input_projection, attention_weights)
        outputs = tf.reduce_sum(weighted_projection, axis=1)
        return outputs, alpha
def lstm_att_model(class_num):
    embed_dim = 100
    lstm_dim = 512  # 256
    # sequen_len = 150
    with tf.name_scope('inputs'):
        inputs = tf.placeholder(dtype=tf.int32, shape=[None, sequen_len], name='inputs')
        # labels = tf.placeholder(dtype=tf.float32, shape=[None, class_num], name='labels')
        labels_input = tf.placeholder(dtype=tf.int32, shape=[None], name='labels')
        labels = tf.one_hot(labels_input, depth=class_num)
        prob = tf.placeholder_with_default(input=1.0, shape=[], name='dropout')
        mask = tf.equal(inputs, 0, name='mask')
        title = tf.placeholder(dtype=tf.int32, shape=[None, title_len], name='title')
        mask_title = tf.equal(title, 0, name='mask_title')
    with tf.variable_scope('embedding'):
        w = tf.Variable(initial_value=embedding_matrix, dtype=tf.float32)
        # w = tf.get_variable(name='embded_w', shape=[height, width], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        embedding = tf.nn.embedding_lookup(w, inputs)
        # embedding = tf.nn.dropout(embedding, prob)
        title_emb = tf.nn.embedding_lookup(w, title)
        # title_emb = tf.nn.dropout(title_emb, prob)
    with tf.variable_scope('net'):
        forward = tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True, dtype=tf.float32)
        backward = tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True, dtype=tf.float32)
        # forward = tf.nn.rnn_cell.DropoutWrapper(forward, output_keep_prob=prob)
        # backward = tf.nn.rnn_cell.DropoutWrapper(backward, output_keep_prob=prob)
        outputs, state = tf.nn.bidirectional_dynamic_rnn(
            forward,
            backward,
            embedding,
            sequence_length=tf.cast(tf.reduce_sum(tf.sign(tf.abs(inputs)), reduction_indices=1), tf.int32),
            dtype=tf.float32
        )
        # bi_output = tf.concat(outputs, axis=-1)
        bi_output = tf.add(outputs[0], outputs[1])
        bi_output = tf.nn.dropout(bi_output, keep_prob=0.5)
        att_output, alpha = attention(bi_output, mask)
        # att_output, alpha = attention_new(bi_output, mask)
        # att_output, alpha = attention_han(bi_output)
        # drop_content = tf.nn.dropout(att_output, keep_prob=prob)
        output_title, state_title = tf.nn.bidirectional_dynamic_rnn(
            forward,
            backward,
            title_emb,
            sequence_length=tf.cast(tf.reduce_sum(tf.sign(tf.abs(title)), reduction_indices=1), tf.int32),
            dtype=tf.float32
        )
        # bi_title = tf.concat(output_title, axis=-1)[:,-1,:]
        bi_title = tf.add(output_title[0], output_title[1])  # [:,-1,:]
        bi_title = tf.nn.dropout(bi_title, keep_prob=prob)
        # bi_title = tf.concat(output_title, axis=-1)
        bi_title, alpha_title = attention(bi_title, mask_title)
        drop_output = tf.concat([bi_title, att_output], axis=-1)
        # drop_output = tf.add(bi_title, att_output)
        # drop_output = att_output
    with tf.variable_scope('output'):
        softmax_w = tf.get_variable('softmax_w', shape=[lstm_dim * 2, class_num], dtype=tf.float32)  # [lstm_dim*2, class_num]
        softmax_output = tf.nn.softmax(tf.matmul(drop_output, softmax_w), name='softmax')
        logit = tf.argmax(softmax_output, axis=-1, name='logit')
    with tf.name_scope(name='loss'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=softmax_output), name='loss')
    with tf.name_scope(name='metric'):
        _p = precision(labels, softmax_output)
        _r = recall(labels, softmax_output)
        _f1 = f1_score(labels, softmax_output)
    with tf.name_scope(name='train_op'):
        optimizer = tf.train.AdamOptimizer(learning_rate=0.0007)
        # optimizer = tf.train.AdadeltaOptimizer(learning_rate=0.1)  # tf.train.GradientDescentOptimizer()  # tf.train.AdadeltaOptimizer()
        global_step = tf.Variable(0, trainable=False)
        grads_vars = optimizer.compute_gradients(loss=loss)
        capped_grads_vars = [[tf.clip_by_value(g, -5, 5), v] for g, v in grads_vars]
        train_op = optimizer.apply_gradients(capped_grads_vars, global_step)
    return inputs, labels_input, prob, logit, loss, train_op, _p, _r, _f1, alpha, title, softmax_output  # ,alpha_title
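# Shape of the graph built by lstm_att_model(): content and title token ids share one embedding
# matrix and one pair of LSTM cells; forward and backward outputs are merged by element-wise
# addition (not concat), each sequence is reduced by attention(), and the title and content
# vectors are concatenated (hence the [lstm_dim*2, class_num] softmax weight) before the softmax
# classifier. Note that the tensor passed to softmax_cross_entropy_with_logits_v2 as `logits`
# has already been through tf.nn.softmax.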
def lstm_att_model_withoutEmb(class_num):
    embed_dim = 100
    lstm_dim = 512  # 256
    # sequen_len = 150
    with tf.name_scope('inputs'):
        content_emb = tf.placeholder(dtype=tf.float32, shape=[None, sequen_len, width], name='inputs')
        # labels = tf.placeholder(dtype=tf.float32, shape=[None, class_num], name='labels')
        labels_input = tf.placeholder(dtype=tf.int32, shape=[None], name='labels')
        labels = tf.one_hot(labels_input, depth=class_num)
        prob = tf.placeholder_with_default(input=1.0, shape=[], name='dropout')
        mask = tf.placeholder(dtype=tf.int32, shape=[None, sequen_len], name='mask')
        doc_length = tf.cast(tf.reduce_sum(1 - mask, reduction_indices=1), tf.int32)
        title_emb = tf.placeholder(dtype=tf.float32, shape=[None, title_len, width], name='title')
        mask_title = tf.placeholder(dtype=tf.int32, shape=[None, title_len], name='mask_title')
        title_length = tf.cast(tf.reduce_sum(1 - mask_title, reduction_indices=1), tf.int32)
    # with tf.variable_scope('embedding'):
    #     w = tf.Variable(initial_value=embedding_matrix, dtype=tf.float32)
    #     # w = tf.get_variable(name='embded_w', shape=[height, width], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
    #     embedding = tf.nn.embedding_lookup(w, inputs)
    #     # embedding = tf.nn.dropout(embedding, prob)
    #
    #     title_emb = tf.nn.embedding_lookup(w, title)
    #     title_emb = tf.nn.dropout(title_emb, prob)
    with tf.variable_scope('net'):
        forward = tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True, dtype=tf.float32)
        backward = tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True, dtype=tf.float32)
        # forward = tf.nn.rnn_cell.DropoutWrapper(forward, output_keep_prob=prob)
        # backward = tf.nn.rnn_cell.DropoutWrapper(backward, output_keep_prob=prob)
        outputs, state = tf.nn.bidirectional_dynamic_rnn(
            forward,
            backward,
            content_emb,
            sequence_length=doc_length,
            dtype=tf.float32
        )
        # bi_output = tf.concat(outputs, axis=-1)
        bi_output = tf.add(outputs[0], outputs[1])
        bi_output = tf.nn.dropout(bi_output, keep_prob=prob)
        att_output, alpha = attention(bi_output, mask)
        # att_output, alpha = attention_new(bi_output, mask)
        # att_output, alpha = attention_han(bi_output)
        # drop_content = tf.nn.dropout(att_output, keep_prob=prob)
        output_title, state_title = tf.nn.bidirectional_dynamic_rnn(
            forward,
            backward,
            title_emb,
            sequence_length=title_length,
            dtype=tf.float32
        )
        # bi_title = tf.concat(output_title, axis=-1)[:,-1,:]
        bi_title = tf.add(output_title[0], output_title[1])  # [:,-1,:]
        bi_title = tf.nn.dropout(bi_title, keep_prob=prob)
        # bi_title = tf.concat(output_title, axis=-1)
        bi_title, alpha_title = attention(bi_title, mask_title)
        drop_output = tf.concat([bi_title, att_output], axis=-1)
        # drop_output = tf.add(bi_title, att_output)
        # drop_output = att_output
    with tf.variable_scope('output'):
        softmax_w = tf.get_variable('softmax_w', shape=[lstm_dim * 2, class_num], dtype=tf.float32)  # [lstm_dim*2, class_num]
        softmax_output = tf.nn.softmax(tf.matmul(drop_output, softmax_w), name='softmax')
        logit = tf.argmax(softmax_output, axis=-1, name='logit')
    with tf.name_scope(name='loss'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=softmax_output), name='loss')
    with tf.name_scope(name='metric'):
        _p = precision(labels, softmax_output)
        _r = recall(labels, softmax_output)
        _f1 = f1_score(labels, softmax_output)
    with tf.name_scope(name='train_op'):
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        # optimizer = tf.train.AdadeltaOptimizer(learning_rate=0.1)  # tf.train.GradientDescentOptimizer()  # tf.train.AdadeltaOptimizer()
        global_step = tf.Variable(0, trainable=False)
        grads_vars = optimizer.compute_gradients(loss=loss)
        capped_grads_vars = [[tf.clip_by_value(g, -5, 5), v] for g, v in grads_vars]
        train_op = optimizer.apply_gradients(capped_grads_vars, global_step)
    return content_emb, mask, labels_input, prob, logit, loss, train_op, _p, _r, _f1, alpha, title_emb, mask_title, softmax_output  # ,alpha_title
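# lstm_att_model_withoutEmb() builds the same title+content BiLSTM/attention classifier, but the
# embedding lookup happens outside the graph: the caller feeds pre-looked-up embedding vectors
# (shape [batch, len, width]) and explicit int masks (1 at padded positions, 0 at real tokens),
# so the exported checkpoint does not need to carry the embedding matrix itself.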
def train():
    # import glob
    # kw_dic = {}
    # for file in glob.glob('data/类别关键词/*.txt'):
    #     with open(file, 'r', encoding='utf-8') as f:
    #         text = f.read()
    #         tmp_kw = sorted(set([it for it in text.split('\n') if it]), key=lambda x: len(x), reverse=True)
    #         lb = file.split('_')[-1][:-4]
    #         kw_dic[lb] = tmp_kw
    #         # print(lb, tmp_kw[:3])
    # def find_kw(lb, s):
    #     kw = []
    #     if lb in kw_dic:
    #         for it in re.finditer('|'.join(kw_dic[lb]), s):
    #             kw.append(it.group())
    #     elif lb == '其他公告':
    #         for it in re.finditer('|'.join(kw_dic['新闻资讯']), s):
    #             kw.append(it.group())
    #     return ' '.join(kw)
    # def df_filter(df, num_per_sour=30):
    #     '''过滤没有类别关键词的文章,每个数据源每个类别最多取30篇文章'''
    #     df = df[df.loc[:, 'lbkw>2']==1]
    #     l = []
    #     for source in set(df['web_source_no']):
    #         df_source = df[df.loc[:, 'web_source_no']==source]
    #         for lb in set(df_source['label']):
    #             df_tmp = df_source[df_source.loc[:, 'label']==lb]
    #             if len(df_tmp) > num_per_sour:
    #                 l.append(df_tmp.sample(num_per_sour))
    #             elif len(df_tmp)>1:
    #                 l.append(df_tmp)
    #     df_new = pd.DataFrame()
    #     df_new = df_new.append(l, ignore_index=True)
    #     return df_new
    # df_l = []
    # df = pd.DataFrame()
    # for file in glob.glob('data/docchannel带数据源2021-04-12-16抽取数据*'):
    #     df_tmp = pd.read_excel(file)
    #     df_l.append(df_tmp)
    #     print(file, len(df_tmp))
    # # df = pd.read_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源0413_filter_bidi_process.xlsx')
    # # df1 = pd.read_excel('data/docchannel带数据源0419_source_filter_bidi_process_predict.xlsx')
    # # df = df.append(df1, ignore_index=True)
    # df = df.append(df_l, ignore_index=True)
    # print(df.head(2))
    # df = df[df.loc[:, 'new=label']==1]
    # print('合并后数据总数:%d'%len(df))
    # import gc
    # del df_l
    # print(gc.collect())
    #
    # df.drop_duplicates(subset='segword', inplace=True)
    # df.dropna(subset=['segword'], inplace=True)
    # df.reset_index(drop=True, inplace=True)
    # df.fillna('', inplace=True)
    # if 'relabel' in df.columns:
    #     df['label'] = df.apply(lambda x:x['relabel'] if x['relabel'] not in ['', 1] else x['label'], axis=1)
    #     df['label'] = df['label'].apply(lambda x:'新闻资讯' if x=='其他公告' else x)
    #     print('更新 label 完成')
    # print(df.head(5))
    # df = df[df.loc[:, 'label']!='招标文件']
    #
    # df['类别关键词'] = df.apply(lambda x: find_kw(x['label'], x['segword_title'] + x['segword']), axis=1)
    # df['lbkw>2'] = df['类别关键词'].apply(lambda x: 1 if len(x) > 5 else 0)
    # df = df_filter(df, num_per_sour=10)
    # print('过滤后数据总数:%d'%len(df))
    # lb_path = 'data/id2label.pkl'
    # if os.path.exists(lb_path):
    #     with open(lb_path, 'rb') as f:
    #         id2label = pickle.load(f)
    # else:
    #     labels = sorted(list(set(df['label'])))
    #     id2label = {k:v for k,v in enumerate(labels)}
    #     with open(lb_path, 'wb') as f:
    #         pickle.dump(id2label, f)
    # lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '招标补充', '中标信息', '合同公告', '废标公告']
    lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
    id2label = {k: v for k, v in enumerate(lb)}
    label2id = {v: k for k, v in id2label.items()}
    # assert set(label2id)==set(df['label'])
    # # df1 = pd.read_excel('data/按数据源类别抽取重新标注数据_predict_类型预测.xlsx')
    # # df = pd.read_excel('data/公告类型标注数据2021-05-26_bidi_process_predict_类型预测.xlsx')
    # # df = df.append(df1, ignore_index=True)
    # # df = df[df.loc[:, 'relabel'].isin(lb)]
    # # df.drop_duplicates(subset=['segword'], inplace=True)
    # # df.reset_index(drop=True, inplace=True)
    # # if 'relabel' in df.columns:
    # #     df['relabel'] = df['relabel'].apply(lambda x:'招标答疑' if x=='招标补充' else x)
    # #     df['label'] = df.apply(lambda x:x['relabel'] if x['relabel'] not in ['', 1, 0] else x['label'], axis=1)
    # # df = df[df.loc[:, 'relabel'].isin(lb)]
    # # df.dropna(subset=['segword'], inplace=True)
    # # df_train , df_test = split_train_test(df, split_rate=0.2)
    # # df_train.reset_index(drop=True, inplace=True)
    # # df_test.reset_index(drop=True, inplace=True)
    # # df_train.to_excel('data/df_train.xlsx', columns=['segword', 'segword_title', 'label'])
    # # df_test.to_excel('data/df_test.xlsx')
    #
    # df_train = pd.read_excel('data/df_train.xlsx')
    # # df_train = df_train.append(df, ignore_index=True)
    # # df_train = df_train[:20000]
    # df_train = df_train.sample(frac=1)
    df_test = pd.read_excel('data/df_test.xlsx')
    df_test = df_test.sample(frac=1)
    # assert set(df_train['label'])==set(label2id)
    # print(df_train.head(3))
    # data_train, label_train, title_train, df_train = data_process(df_train, label2id=label2id)  # df_train
    # data_test, label_test, title_test, df_test = data_process(df_test, label2id=label2id)  # df_test
    # data_train, label_train, title_train, df_train = data_process_sentence(df_train, label2id=label2id)  # df_train
    data_test, label_test, title_test, df_test = data_process_sentence(df_test, label2id=label2id)  # df_test
    # print('data_tran.shape', data_train.shape, label_train.shape)
    print('word_index大小 :', len(word_index), ',' in word_index)
    file_num = 4  # int((len(data_train)-1)/10000)+1
    # for i in range(file_num):
    #     with open('data/train_data/data_train{}.pkl'.format(i), 'wb') as f:
    #         pickle.dump(data_train[i*10000:(i+1)*10000], f)
    #     with open('data/train_data/title_train{}.pkl'.format(i), 'wb') as f:
    #         pickle.dump(title_train[i*10000:(i+1)*10000], f)
    #     with open('data/train_data/label_train{}.pkl'.format(i), 'wb') as f:
    #         pickle.dump(label_train[i*10000:(i+1)*10000], f)
    import gc
    import time
    # del df_train
    # del df
    # del data_train
    # del label_train
    # del title_train
    del df_test
    print('清除内存', gc.collect())
    time.sleep(1)
    print('清除内存', gc.collect())
    # word_index, tokenizer, embedding_matrix = get_embedding()
    inputs, labels, prob, logit, loss, train_op, _p, _r, _f1, alpha, title, softmax_output = lstm_att_model(
        len(id2label))
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.55)
    # config = tf.ConfigProto(gpu_options=gpu_options)
    # config = tf.ConfigProto(allow_soft_placement=True)
    # config.gpu_options.per_process_gpu_memory_fraction = 0.45
    # config.gpu_options.allow_growth = True
    batch_size = 128
    min_loss = 10
    train_losses = []
    val_losses = []
    max_f1 = 0
    with tf.Session() as sess:  # config=config
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print(alpha)
        # saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat0607_adadelta.ckpt')
        saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt')
        for epoch in range(80):
            batch_loss = []
            batch_f1 = []
            # tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
            # print('当前节点数量',len(tensor_name_list))
            for i in range(file_num):
                with open('data/train_data/data_train{}.pkl'.format(i), 'rb') as f:
                    data_train = pickle.load(f)
                with open('data/train_data/title_train{}.pkl'.format(i), 'rb') as f:
                    title_train = pickle.load(f)
                with open('data/train_data/label_train{}.pkl'.format(i), 'rb') as f:
                    label_train = pickle.load(f)
                for i in range(int((len(data_train) - 1) / batch_size) + 1):
                    _, loss_, logit_, p, r, f1 = sess.run([train_op, loss, logit, _p, _r, _f1],
                                                          feed_dict={
                                                              inputs: data_train[i * batch_size:(i + 1) * batch_size],
                                                              title: title_train[i * batch_size:(i + 1) * batch_size],
                                                              labels: label_train[i * batch_size:(i + 1) * batch_size],
                                                              prob: 0.5}
                                                          # feed_dict={
                                                          #     inputs: np.array(data_train[i * batch_size:(i + 1) * batch_size]),
                                                          #     title: np.array(title_train[i * batch_size:(i + 1) * batch_size]),
                                                          #     labels: label_train[i * batch_size:(i + 1) * batch_size],
                                                          #     prob: 0.5}
                                                          )
                    # print(loss_, p, r, f1)
                    batch_f1.append(f1)
                    batch_loss.append(loss_)
            print('训练 平均损失:%.4f, 平均f1:%.4f' % (np.mean(batch_loss), np.mean(batch_f1)))
            train_losses.append(np.mean(batch_loss))
            batch_loss = []
            batch_f1 = []
            for i in range(int((len(data_test) - 1) / batch_size) + 1):
                loss_, p, r, f1 = sess.run([loss, _p, _r, _f1],
                                           feed_dict={inputs: data_test[i * batch_size:(i + 1) * batch_size],
                                                      title: title_test[i * batch_size:(i + 1) * batch_size],
                                                      labels: label_test[i * batch_size:(i + 1) * batch_size],
                                                      prob: 1}
                                           # feed_dict={inputs: np.array(data_test[i * batch_size:(i + 1) * batch_size]),
                                           #            title: np.array(title_test[i * batch_size:(i + 1) * batch_size]),
                                           #            labels: label_test[i * batch_size:(i + 1) * batch_size],
                                           #            prob: 1}
                                           )
                # print('val_loss, p, r, f1:', loss_, p, r, f1)
                batch_f1.append(f1)
                batch_loss.append(loss_)
            print('第%d轮,val 平均损失:%.4f, 平均f1:%.4f' % (epoch, np.mean(batch_loss), np.mean(batch_f1)))
            val_losses.append(np.mean(batch_loss))
            if min_loss > np.mean(batch_loss):  # max_f1<np.mean(batch_f1) and
                max_f1 = np.mean(batch_f1)
                min_loss = np.mean(batch_loss)
                saver.save(sess,
                           'model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt')  # 0416 # channel_title+content_xavier_emb.ckpt channel_title+content
                print('第%d轮,loss:%.4f, f1:%.4f 模型保存成功! ' % (epoch, np.mean(batch_loss), np.mean(batch_f1)))  # concat0521
                # channel_foolcut_title_lstm_content_att_concat0607_adadelta
    from matplotlib import pyplot
    with open('data/train_loss.pkl', 'wb') as f:
        pickle.dump(train_losses, f)
    with open('data/val_loss.pkl', 'wb') as f:
        pickle.dump(val_losses, f)
    # pyplot.plot(train_losses)
    # pyplot.plot(val_losses)
    # pyplot.title('train and val loss')
    # pyplot.ylabel('loss')
    # pyplot.xlabel('epoch')
    # pyplot.legend(['train', 'val'], loc='upper right')
    # pyplot.show()
def predict():
    batch_size = 512
    lb_path = 'data/id2label.pkl'
    # lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '招标补充', '中标信息', '合同公告', '废标公告']
    lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
    id2label = {k: v for k, v in enumerate(lb)}
    label2id = {v: k for k, v in id2label.items()}
    # if os.path.exists(lb_path):
    #     with open(lb_path, 'rb') as f:
    #         id2label = pickle.load(f)
    #     label2id = {v: k for k, v in id2label.items()}
    print(label2id)
    df_test = pd.read_excel('data/docchannel带数据源_bidi_process_0420日之前标注每数据源每类别抽取5篇数据.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('data/docchannel带数据源2021-04-16_bidi_process_predict.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('data/按数据源类别抽取重新标注数据_predict_类型预测.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('data/df_test.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('data/docchannel带数据源2021-04-12_bidi_process_predict2_2.xlsx.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源_bidi_process.xlsx')  # df_test_all.xlsx
    # l = []
    # for sour in set(df_test['web_source_no']):
    #     df_tmp = df_test[df_test.loc[:, 'web_source_no']==sour]
    #     if len(df_tmp)>5:
    #         l.append(df_tmp.sample(5))
    # df_test = pd.DataFrame()
    # df_test = df_test.append(l, ignore_index=True)
    # df_test = df_test[df_test.loc[:, 'label'] != '招标文件']
    # df_test['label_old'] = df_test['label']
    df_test.dropna(subset=['segword'], inplace=True)
    df_test.reset_index(drop=True, inplace=True)
    df_test.fillna('', inplace=True)
    if 'relabel' in df_test.columns:
        df_test['relabel'] = df_test['relabel'].apply(lambda x: '招标答疑' if x == '招标补充' else x)
        df_test['relabel'] = df_test['relabel'].apply(lambda x: '新闻资讯' if x == '其他公告' else x)
        # df_test['label'] = df_test.apply(lambda x:x['relabel'] if x['relabel'] not in ['', 1, 0] else x['label'], axis=1)
        df_test['label'] = df_test.apply(lambda x: x['relabel'] if x['relabel'] in lb else x['label'], axis=1)
        df_test['label'] = df_test['label'].apply(lambda x: '新闻资讯' if x == '其他公告' else x)
        print('更新 label 完成')
    # assert set(df_test['label']) == set(label2id)
    # data_test, label_test = data_process(df_test, label2id=label2id)
    # data_test, label_test, title_test, df_test = data_process(df_test, label2id=label2id)
    data_test, label_test, title_test, df_test = data_process_sentence(df_test, label2id=label2id)
    batch_size = 128
    predicts = []
    alphas = []
    alpha_t = []
    max_porb = []
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
    # config = tf.ConfigProto(gpu_options=gpu_options)
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt.meta')  # 0518
        saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt')  # 0511 adadelta
        inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
        prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
        labels = sess.graph.get_tensor_by_name('inputs/labels:0')
        title = sess.graph.get_tensor_by_name('inputs/title:0')
        logit = sess.graph.get_tensor_by_name('output/logit:0')
        softmax_output = sess.graph.get_tensor_by_name('output/softmax:0')
        alpha = sess.graph.get_tensor_by_name('net/alphas:0')
        # alpha = sess.graph.get_tensor_by_name('net/attention/alphas:0')
        # alpha_title = sess.graph.get_tensor_by_name('net/alphas_1:0')
        print(alpha)
        # print(alpha_title)
        for i in range(int((len(df_test) - 1) / batch_size) + 1):
            logit_, alpha_, softmax_output_ = sess.run([logit, alpha, softmax_output],  # ,alpha_title alpha,
                                                       feed_dict={inputs: data_test[i * batch_size:(i + 1) * batch_size],
                                                                  title: title_test[i * batch_size:(i + 1) * batch_size],
                                                                  labels: label_test[i * batch_size:(i + 1) * batch_size],
                                                                  prob: 1})
            predicts.extend(logit_)  # logit_[0]
            alphas.extend(alpha_)
            max_porb.extend(np.max(softmax_output_, axis=-1))
            # alpha_t.extend(alpha_title_)
    assert len(predicts) == len(df_test)
    assert len(alphas) == len(df_test)
    pred_new = [id2label[id] for id in predicts]
    # df_test['pred_old'] = df_test['pred_new']
    # df_test['old=label'] = df_test['new=label']
    df_test['pred_new'] = pd.Series(pred_new)
    df_test['new=label'] = df_test.apply(lambda x: 1 if x['pred_new'] == x['label'] else 0, axis=1)
    # df_test['new=old'] = df_test.apply(lambda x: 1 if x['pred_new'] == x['pred_old'] else 0, axis=1)
    # df_test['pred_new'] = pd.Series(pred_new)
    # df_test['new=label'] = df_test.apply(lambda x:1 if x['pred_new']==x['label'] else 0, axis=1)
    keywords = []
    for i in range(len(alphas)):
        # words = df_test.loc[i, 'segword'].split()
        words = df_test.loc[i, 'content_input'].split()
        # words = [w for w in words if re.search('[\u4e00-\u9fa5]', w)]
        # words = (df_test.loc[i, 'segword']+df_test.loc[i, 'segword_title']).split()\
        #     if isinstance(df_test.loc[i, 'segword_title'], str) and df_test.loc[i, 'segword_title'] not in \
        #     df_test.loc[i, 'segword'] else df_test.loc[i, 'segword'].split()
        # words = [w for w in words if re.search('[\u4e00-\u9fa5]', w) and w in word_index]
        ids = np.argsort(-alphas[i])
        tmp_word = []
        for j in ids[:10]:
            if j < len(words):
                tmp_word.append(words[j])
            else:
                tmp_word.append('pad')
        keywords.append(tmp_word)
    df_test['keyword'] = pd.Series(keywords)
    # df_test['keyword_title'] = pd.Series(keyword_title)
    df_test['pred_prob'] = pd.Series(max_porb)
    df_test.sort_values(by=['new=label', 'label', 'pred_new'], inplace=True)
    print(df_test.head(5))
    # df_test['old=new'] = df_test.apply(lambda x:1 if x['pred_new']==x['pred'] else 0, axis=1)
    df_test.to_excel('data/docchannel带数据源_bidi_process_0420日之前标注每数据源每类别抽取5篇数据_predict.xlsx')
    # df_test.to_excel('data/docchannel带数据源0419_source_filter_bidi_process_predict.xlsx')
    # df_test.to_excel('data/按数据源类别抽取重新标注数据_predict_类型预测_predict.xlsx')  # 按数据源类别抽取重新标注数据_predict df_test_predict.xlsx
    # df_test.to_excel('data/docchannel带数据源2021-04-12_bidi_process_predict2_2.xlsx')  # data/df_test_predict.xlsx
    # df_test.to_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源_bidi_process_predict.xlsx',  # 'data/df_test_predict.xlsx',
    #                  columns=['docid', 'doctitle', 'dochtmlcon', 'relabel', 'label', 'new=label', 'pred_new',  # 'pred_new3', 'new=label3', 'pred_new2', 'new=label2',
    #                           'pred_prob', 'keyword', 'segword', 'segword_title',
    #                           # 'sub_docs_json', 'dochtmlcon', 'docchannel', 'page_time', 'status','agency', 'tenderee', 'len(segword)'
    #                           ])
    get_acc_recall(df_test)
def train_withoutEmb():
    lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
    id2label = {k: v for k, v in enumerate(lb)}
    label2id = {v: k for k, v in id2label.items()}
    batch_size = 256
    # assert set(label2id)==set(df['label'])
    df1 = pd.read_excel('data/按数据源类别抽取重新标注数据_predict_类型预测.xlsx')
    df = pd.read_excel('data/公告类型标注数据2021-05-26_bidi_process_predict_类型预测.xlsx')
    # df1 = pd.read_excel('data/按数据源类别抽取重新标注数据_predict_类型预测_分开候选人公示.xlsx')
    # df = pd.read_excel('data/公告类型标注数据2021-05-26_bidi_process_predict_类型预测_分开候选人公示.xlsx')
    df = df.append(df1, ignore_index=True)
    # df = df[df.loc[:, 'relabel'].isin(lb)]
    df.drop_duplicates(subset=['segword'], inplace=True)
    df.reset_index(drop=True, inplace=True)
    if 'relabel' in df.columns:
        df['relabel'] = df['relabel'].apply(lambda x: '中标信息' if x == '候选人公示' else x)
        df['label'] = df.apply(lambda x: x['relabel'] if x['relabel'] not in ['', 1, 0] else x['label'], axis=1)
        df = df[df.loc[:, 'relabel'].isin(lb)]
    df.dropna(subset=['segword'], inplace=True)
    df_train, df_test = split_train_test(df, split_rate=0.10)
    df_train.reset_index(drop=True, inplace=True)
    df_test.reset_index(drop=True, inplace=True)
    df_train.to_excel('data/df_train.xlsx', columns=['segword', 'segword_title', 'label'])
    df_test.to_excel('data/df_test.xlsx')
    df_train = pd.read_excel('data/df_train.xlsx')
    # df_train = df_train.append(df, ignore_index=True)
    # df_train = df_train[:20000]
    df_train = df_train.sample(frac=1)
    df_test = pd.read_excel('data/df_test.xlsx')
    df_test = df_test.sample(frac=1)
    # assert set(df_train['label'])==set(label2id)
    # print(df_train.head(3))
    # data_train, label_train, title_train, df_train = data_process(df_train, label2id=label2id)  # df_train
    # data_test, label_test, title_test, df_test = data_process(df_test, label2id=label2id)  # df_test
    data_train, label_train, title_train, df_train = data_process_sentence(df_train, label2id=label2id)  # df_train
    data_test, label_test, title_test, df_test = data_process_sentence(df_test, label2id=label2id)  # df_test
    # print('data_tran.shape', data_train.shape, label_train.shape)
    print('word_index大小 :', len(word_index), ',' in word_index)
    file_num = int((len(data_train) - 1) / (100 * batch_size)) + 1
    print('file_num', file_num)
    for i in range(file_num):
        # print('写文件', i*100*batch_size, (i+1)*100*batch_size)
        with open('data/train_data_lift/data_train{}.pkl'.format(i), 'wb') as f:
            pickle.dump(data_train[i * 100 * batch_size:(i + 1) * 100 * batch_size], f)
        with open('data/train_data_lift/title_train{}.pkl'.format(i), 'wb') as f:
            pickle.dump(title_train[i * 100 * batch_size:(i + 1) * 100 * batch_size], f)
        with open('data/train_data_lift/label_train{}.pkl'.format(i), 'wb') as f:
            pickle.dump(label_train[i * 100 * batch_size:(i + 1) * 100 * batch_size], f)
    import gc
    import time
    # del df_train
    # del df
    # del data_train
    # del label_train
    # del title_train
    del df_test
    print('清除内存', gc.collect())
    time.sleep(1)
    print('清除内存', gc.collect())
    # word_index, tokenizer, embedding_matrix = get_embedding()
    inputs, mask, labels, prob, logit, loss, train_op, _p, _r, _f1, alpha, title, mask_title, \
        softmax_output = lstm_att_model_withoutEmb(len(id2label))
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.55)
    # config = tf.ConfigProto(gpu_options=gpu_options)
    # config = tf.ConfigProto(allow_soft_placement=True)
    # config.gpu_options.per_process_gpu_memory_fraction = 0.45
    # config.gpu_options.allow_growth = True
    min_loss = 10
    train_losses = []
    val_losses = []
    max_f1 = 0
    with tf.Session() as sess:  # config=config
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print(alpha)
        # saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat_withoutEmb0621_adam.ckpt')
        # saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt')
        for epoch in range(80):
            batch_loss = []
            batch_f1 = []
            # tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
            # print('当前节点数量',len(tensor_name_list))
            for i in range(file_num):
                with open('data/train_data_lift/data_train{}.pkl'.format(i), 'rb') as f:
                    data_train = pickle.load(f)
                with open('data/train_data_lift/title_train{}.pkl'.format(i), 'rb') as f:
                    title_train = pickle.load(f)
                with open('data/train_data_lift/label_train{}.pkl'.format(i), 'rb') as f:
                    label_train = pickle.load(f)
                for i in range(int((len(data_train) - 1) / batch_size) + 1):
                    _, loss_, logit_, p, r, f1 = sess.run([train_op, loss, logit, _p, _r, _f1],
                                                          feed_dict={
                                                              inputs: [[embedding_matrix[i] for i in l] for l in data_train[i * batch_size:(i + 1) * batch_size]],
                                                              title: [[embedding_matrix[i] for i in l] for l in title_train[i * batch_size:(i + 1) * batch_size]],
                                                              mask: 1 - np.not_equal(data_train[i * batch_size:(i + 1) * batch_size], 0),
                                                              mask_title: 1 - np.not_equal(title_train[i * batch_size:(i + 1) * batch_size], 0),
                                                              labels: label_train[i * batch_size:(i + 1) * batch_size],
                                                              prob: 0.5}
                                                          # feed_dict={
                                                          #     inputs: np.array(data_train[i * batch_size:(i + 1) * batch_size]),
                                                          #     title: np.array(title_train[i * batch_size:(i + 1) * batch_size]),
                                                          #     labels: label_train[i * batch_size:(i + 1) * batch_size],
                                                          #     prob: 0.5}
                                                          )
                    # print(loss_, p, r, f1)
                    batch_f1.append(f1)
                    batch_loss.append(loss_)
            print('训练 平均损失:%.4f, 平均f1:%.4f' % (np.mean(batch_loss), np.mean(batch_f1)))
            train_losses.append(np.mean(batch_loss))
            batch_loss = []
            batch_f1 = []
            for i in range(int((len(data_test) - 1) / batch_size) + 1):
                loss_, p, r, f1 = sess.run([loss, _p, _r, _f1],
                                           feed_dict={
                                               inputs: [[embedding_matrix[i] for i in l] for l in
                                                        data_test[i * batch_size:(i + 1) * batch_size]],
                                               title: [[embedding_matrix[i] for i in l] for l in
                                                       title_test[i * batch_size:(i + 1) * batch_size]],
                                               mask: 1 - np.not_equal(data_test[i * batch_size:(i + 1) * batch_size], 0),
                                               mask_title: 1 - np.not_equal(title_test[i * batch_size:(i + 1) * batch_size], 0),
                                               labels: label_test[i * batch_size:(i + 1) * batch_size],
                                               prob: 1}
                                           # feed_dict={inputs: np.array(data_test[i * batch_size:(i + 1) * batch_size]),
                                           #            title: np.array(title_test[i * batch_size:(i + 1) * batch_size]),
                                           #            labels: label_test[i * batch_size:(i + 1) * batch_size],
                                           #            prob: 1}
                                           )
                # print('val_loss, p, r, f1:', loss_, p, r, f1)
                batch_f1.append(f1)
                batch_loss.append(loss_)
            print('第%d轮,val 平均损失:%.4f, 平均f1:%.4f' % (epoch, np.mean(batch_loss), np.mean(batch_f1)))
            val_losses.append(np.mean(batch_loss))
            if min_loss > np.mean(batch_loss):  # max_f1<np.mean(batch_f1) and
                max_f1 = np.mean(batch_f1)
                min_loss = np.mean(batch_loss)
                saver.save(sess,
                           'model/channel_foolcut_title_lstm_content_att_concat_withoutEmb0621_adam.ckpt')  # 0416 # channel_title+content_xavier_emb.ckpt channel_title+content
                print('第%d轮,loss:%.4f, f1:%.4f 模型保存成功! ' % (epoch, np.mean(batch_loss), np.mean(batch_f1)))  # concat0521
                # channel_foolcut_title_lstm_content_att_concat0607_adadelta
    from matplotlib import pyplot
    with open('data/train_loss.pkl', 'wb') as f:
        pickle.dump(train_losses, f)
    with open('data/val_loss.pkl', 'wb') as f:
        pickle.dump(val_losses, f)
  945. def predict_withoutEmb():
    batch_size = 512
    lb_path = 'data/id2label.pkl'
    # lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '招标补充', '中标信息', '合同公告', '废标公告']
    lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
    id2label = {k: v for k, v in enumerate(lb)}
    label2id = {v: k for k, v in id2label.items()}
    # if os.path.exists(lb_path):
    #     with open(lb_path, 'rb') as f:
    #         id2label = pickle.load(f)
    #     label2id = {v: k for k, v in id2label.items()}
    print(label2id)
    # df_test = pd.read_excel('data/docchannel带数据源_bidi_process_0420日之前标注每数据源每类别抽取5篇数据_predict.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('data/docchannel带数据源2021-04-16_bidi_process_predict.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('data/按数据源类别抽取重新标注数据_predict_类型预测.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('data/df_test.xlsx')  # df_test_all.xlsx
    df_test = pd.read_excel('data/docchannel带数据源2021-04-12-13-15-16预测错误数据源.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('data/docchannel带数据源2021-04-12_bidi_process_predict2_2.xlsx.xlsx')  # df_test_all.xlsx
    # df_test = pd.read_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源_bidi_process.xlsx')  # df_test_all.xlsx
    # l = []
    # for sour in set(df_test['web_source_no']):
    #     df_tmp = df_test[df_test.loc[:, 'web_source_no'] == sour]
    #     if len(df_tmp) > 5:
    #         l.append(df_tmp.sample(5))
    # df_test = pd.DataFrame()
    # df_test = df_test.append(l, ignore_index=True)
    # df_test = df_test[df_test.loc[:, 'label'] != '招标文件']
    # df_test['label_old'] = df_test['label']
    df_test.dropna(subset=['segword'], inplace=True)
    df_test.reset_index(drop=True, inplace=True)
    df_test.fillna('', inplace=True)
    if 'relabel' in df_test.columns:
        df_test['relabel'] = df_test['relabel'].apply(lambda x: '招标答疑' if x == '招标补充' else x)
        df_test['relabel'] = df_test['relabel'].apply(lambda x: '新闻资讯' if x == '其他公告' else x)
        # df_test['label'] = df_test.apply(lambda x: x['relabel'] if x['relabel'] not in ['', 1, 0] else x['label'], axis=1)
        df_test['label'] = df_test.apply(lambda x: x['relabel'] if x['relabel'] in lb else x['label'], axis=1)
    df_test['label'] = df_test['label'].apply(lambda x: '新闻资讯' if x == '其他公告' else x)
    print('Labels updated')
    # assert set(df_test['label']) == set(label2id)
    # data_test, label_test = data_process(df_test, label2id=label2id)
    # data_test, label_test, title_test, df_test = data_process(df_test, label2id=label2id)
    data_test, label_test, title_test, df_test = data_process_sentence(df_test, label2id=label2id)
    batch_size = 128
    predicts = []
    alphas = []
    alpha_t = []
    max_prob = []
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
    # config = tf.ConfigProto(gpu_options=gpu_options)
    with tf.Session() as sess:
        # saver = tf.train.import_meta_graph('model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt.meta')  # 0518
        # saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt')  # 0511 adadelta
        saver = tf.train.import_meta_graph('model/channel_foolcut_title_lstm_content_att_concat_withoutEmb0621_adam.ckpt.meta')  # 0518
        saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat_withoutEmb0621_adam.ckpt')  # 0511 adadelta
        inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
        mask = sess.graph.get_tensor_by_name('inputs/mask:0')
        mask_title = sess.graph.get_tensor_by_name('inputs/mask_title:0')
        prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
        labels = sess.graph.get_tensor_by_name('inputs/labels:0')
        title = sess.graph.get_tensor_by_name('inputs/title:0')
        logit = sess.graph.get_tensor_by_name('output/logit:0')
        softmax_output = sess.graph.get_tensor_by_name('output/softmax:0')
        alpha = sess.graph.get_tensor_by_name('net/alphas:0')
        # alpha = sess.graph.get_tensor_by_name('net/attention/alphas:0')
        # alpha_title = sess.graph.get_tensor_by_name('net/alphas_1:0')
        print(alpha)
        # print(alpha_title)
        for i in range(int((len(df_test) - 1) / batch_size) + 1):
            logit_, alpha_, softmax_output_ = sess.run([logit, alpha, softmax_output],  # ,alpha_title  alpha,
                                                       feed_dict={
                                                           inputs: [[embedding_matrix[i] for i in l] for l in
                                                                    data_test[i * batch_size:(i + 1) * batch_size]],
                                                           title: [[embedding_matrix[i] for i in l] for l in
                                                                   title_test[i * batch_size:(i + 1) * batch_size]],
                                                           mask: 1 - np.not_equal(data_test[i * batch_size:(i + 1) * batch_size], 0),
                                                           mask_title: 1 - np.not_equal(title_test[i * batch_size:(i + 1) * batch_size], 0),
                                                           labels: label_test[i * batch_size:(i + 1) * batch_size],
                                                           prob: 1})
            # feed_dict={inputs: data_test[i * batch_size:(i + 1) * batch_size],
            #            title: title_test[i * batch_size:(i + 1) * batch_size],
            #            labels: label_test[i * batch_size:(i + 1) * batch_size],
            #            prob: 1})
            predicts.extend(logit_)  # logit_[0]
            alphas.extend(alpha_)
            max_prob.extend(np.max(softmax_output_, axis=-1))
            # alpha_t.extend(alpha_title_)
        assert len(predicts) == len(df_test)
        assert len(alphas) == len(df_test)
        pred_new = [id2label[id] for id in predicts]
        # df_test['pred_old'] = df_test['pred_new']
        # df_test['old=label'] = df_test['new=label']
        df_test['pred_new'] = pd.Series(pred_new)
        df_test['new=label'] = df_test.apply(lambda x: 1 if x['pred_new'] == x['label'] else 0, axis=1)
        # df_test['new=old'] = df_test.apply(lambda x: 1 if x['pred_new'] == x['pred_old'] else 0, axis=1)
        # df_test['pred_new'] = pd.Series(pred_new)
        # df_test['new=label'] = df_test.apply(lambda x: 1 if x['pred_new'] == x['label'] else 0, axis=1)
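        # Use the per-document attention weights to surface the ten most
        # attended-to tokens as "keywords"; positions that fall beyond the
        # actual text (padding) are reported as 'pad'.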
        keywords = []
        for i in range(len(alphas)):
            # words = df_test.loc[i, 'segword'].split()
            words = df_test.loc[i, 'content_input'].split()
            # words = [w for w in words if re.search('[\u4e00-\u9fa5]', w)]
            # words = (df_test.loc[i, 'segword'] + df_test.loc[i, 'segword_title']).split() \
            #     if isinstance(df_test.loc[i, 'segword_title'], str) and df_test.loc[i, 'segword_title'] not in \
            #        df_test.loc[i, 'segword'] else df_test.loc[i, 'segword'].split()
            # words = [w for w in words if re.search('[\u4e00-\u9fa5]', w) and w in word_index]
            ids = np.argsort(-alphas[i])
            tmp_word = []
            for j in ids[:10]:
                if j < len(words):
                    tmp_word.append(words[j])
                else:
                    tmp_word.append('pad')
            keywords.append(tmp_word)
        df_test['keyword'] = pd.Series(keywords)
        # df_test['keyword_title'] = pd.Series(keyword_title)
        df_test['pred_prob'] = pd.Series(max_prob)
        df_test.sort_values(by=['new=label', 'label', 'pred_new'], inplace=True)
        print(df_test.head(5))
        # df_test.to_excel('data/df_test_predict.xlsx')
        df_test.to_excel('data/docchannel带数据源2021-04-12-13-15-16预测错误数据源_predict.xlsx')
        # df_test['old=new'] = df_test.apply(lambda x: 1 if x['pred_new'] == x['pred'] else 0, axis=1)
        # df_test.to_excel('data/docchannel带数据源_bidi_process_0420日之前标注每数据源每类别抽取5篇数据_predict.xlsx')
        # df_test.to_excel('data/docchannel带数据源0419_source_filter_bidi_process_predict.xlsx')
        # df_test.to_excel('data/按数据源类别抽取重新标注数据_predict_类型预测_predict.xlsx')  # 按数据源类别抽取重新标注数据_predict df_test_predict.xlsx
        # df_test.to_excel('data/docchannel带数据源2021-04-12_bidi_process_predict2_2.xlsx')  # data/df_test_predict.xlsx
        # df_test.to_excel('/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源_bidi_process_predict.xlsx',  # 'data/df_test_predict.xlsx',
        #                  columns=['docid', 'doctitle', 'dochtmlcon', 'relabel', 'label', 'new=label', 'pred_new',  # 'pred_new3', 'new=label3', 'pred_new2', 'new=label2',
        #                           'pred_prob', 'keyword', 'segword', 'segword_title',
        #                           # 'sub_docs_json', 'dochtmlcon', 'docchannel', 'page_time', 'status', 'agency', 'tenderee', 'len(segword)'
        #                           ])
        get_acc_recall(df_test)
def get_acc_recall(df):
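    """Print per-class recall/precision plus overall precision, recall and F1,
    comparing the gold `label` column with the predicted `pred_new` column."""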
    # df.reset_index(drop=True, inplace=True)
    df.fillna('', inplace=True)
    # df['label'] = df.apply(lambda x: x['relabel'] if x['relabel'] else x['label'], axis=1)
    lab_dic = {}
    for lb in set(df['label']):
        df_tmp = df[df.loc[:, 'label'] == lb]
        lab_dic[lb] = set(df_tmp['docid'])
    pre_dic = {}
    for lb in set(df['pred_new']):
        df_tmp = df[df.loc[:, 'pred_new'] == lb]
        pre_dic[lb] = set(df_tmp['docid'])
    eq_total = lab_total = pre_total = 0
    for lb in sorted(pre_dic):
        if lb in lab_dic:
            eq = len(pre_dic[lb] & lab_dic[lb])
            lab = len(lab_dic[lb])
            pre = len(pre_dic[lb])
            recall = eq / lab if lab > 0 else 0
            acc = eq / pre if pre > 0 else 0
            print('Class: %s; recall: %.4f; precision: %.4f' % (lb, recall, acc))
            eq_total += eq
            lab_total += lab
            pre_total += pre
    rc_total = eq_total / lab_total if lab_total > 0 else 0
    acc_total = eq_total / pre_total if pre_total > 0 else 0
    f1_total = 2 * (rc_total * acc_total) / (rc_total + acc_total) if (rc_total + acc_total) > 0 else 0
    print('Precision: %.4f, recall: %.4f, F1: %.4f' % (acc_total, rc_total, f1_total))
class DocChannel():
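    """Two-stage document classifier built on frozen TensorFlow graphs.

    A first model (`type_model`) decides the broad announcement type; unless the
    document is classified as 新闻资讯 (news, index 4), a second model
    (`life_model`) assigns the tendering life-cycle stage.
    `predict(title, content)` returns (label, probability).
    """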
    def __init__(self, life_model='model/channel.pb', type_model='model/doctype.pb'):
        self.lift_sess, self.lift_title, self.lift_content, self.lift_prob, self.lift_softmax, \
            self.mask, self.mask_title = self.load_life(life_model)
        self.type_sess, self.type_title, self.type_content, self.type_prob, self.type_softmax, \
            self.type_mask, self.type_mask_title = self.load_type(type_model)
        lb_type = ['采招数据', '土地矿产', '拍卖出让', '产权交易', '新闻资讯']
        lb_life = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
        self.id2type = {k: v for k, v in enumerate(lb_type)}
        self.id2life = {k: v for k, v in enumerate(lb_life)}
    def load_life(self, life_model):
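        """Load the frozen life-cycle graph from `life_model` (a .pb file) and
        return its session together with the input/output tensors."""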
        # sess = tf.Session()
        # saver = tf.train.import_meta_graph('model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt.meta')  # 0518
        # saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat0607_adam.ckpt')
        # inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
        # prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
        # title = sess.graph.get_tensor_by_name('inputs/title:0')
        # # logit = sess.graph.get_tensor_by_name('output/logit:0')
        # softmax = sess.graph.get_tensor_by_name('output/softmax:0')
        # return sess, title, inputs, prob, softmax
        with tf.Graph().as_default() as graph:
            output_graph_def = graph.as_graph_def()
            with open(life_model, 'rb') as f:
                output_graph_def.ParseFromString(f.read())
                tf.import_graph_def(output_graph_def, name='')
                print("%d ops in the final graph" % len(output_graph_def.node))
                del output_graph_def
            sess = tf.Session(graph=graph)
            sess.run(tf.global_variables_initializer())
            inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
            prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
            title = sess.graph.get_tensor_by_name('inputs/title:0')
            mask = sess.graph.get_tensor_by_name('inputs/mask:0')
            mask_title = sess.graph.get_tensor_by_name('inputs/mask_title:0')
            # logit = sess.graph.get_tensor_by_name('output/logit:0')
            softmax = sess.graph.get_tensor_by_name('output/softmax:0')
            return sess, title, inputs, prob, softmax, mask, mask_title
    def load_type(self, type_model):
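        """Load the frozen document-type graph from `type_model` (a .pb file) and
        return its session together with the input/output tensors."""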
        with tf.Graph().as_default() as graph:
            output_graph_def = graph.as_graph_def()
            with open(type_model, 'rb') as f:
                output_graph_def.ParseFromString(f.read())
                tf.import_graph_def(output_graph_def, name='')
                print("%d ops in the final graph" % len(output_graph_def.node))
                del output_graph_def
            sess = tf.Session(graph=graph)
            sess.run(tf.global_variables_initializer())
            inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
            prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
            title = sess.graph.get_tensor_by_name('inputs/title:0')
            mask = sess.graph.get_tensor_by_name('inputs/mask:0')
            mask_title = sess.graph.get_tensor_by_name('inputs/mask_title:0')
            # logit = sess.graph.get_tensor_by_name('output/logit:0')
            softmax = sess.graph.get_tensor_by_name('output/softmax:0')
            return sess, title, inputs, prob, softmax, mask, mask_title
    def predict_process(self, docid='', doctitle='', dochtmlcon=''):
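        """Turn a raw title/content pair into padded id sequences for the models.

        The title is segmented with `fool`; the content is expected to be
        pre-segmented. Both are filtered against the vocabulary, keyword-centred
        windows are extracted from long documents, and everything is mapped to
        word ids padded to `sequen_len` / `title_len`."""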
        def get_kw_senten(s, span=10):
            doc_sens = []
            tmp = 0
            num = 0
            end_idx = 0
            for it in re.finditer(kws, s):  # '|'.join(keywordset)
                left = s[end_idx:it.end()].split()
                right = s[it.end():].split()
                tmp_seg = s[tmp:it.start()].split()
                if len(tmp_seg) > span or tmp == 0:
                    doc_sens.append(' '.join(left[-span:] + right[:span]))
                    end_idx = it.end() + 1 + len(' '.join(right[:span]))
                    tmp = it.end()
                    num += 1
                    if num >= sentence_num:
                        break
            if doc_sens == []:
                doc_sens.append(s)
            return doc_sens

        def word2id(wordlist, max_len=sequen_len):
            ids = [word_index.get(w, 0) for w in wordlist]
            ids = ids[:max_len] if len(ids) >= max_len else ids + [0] * (max_len - len(ids))
            assert len(ids) == max_len
            return ids

        import fool
        cost_time = dict()
        datas = []
        datas_title = []
        articles = [[docid, dochtmlcon, '', '', doctitle]]
        try:
            # list_articles = Preprocessing.get_preprocessed_article(articles, cost_time)
            # list_sentences = Preprocessing.get_preprocessed_sentences(list_articles, True, cost_time)
            # sen_words = [sen.tokens for sen in list_sentences[0]]
            # words = [it for sen in sen_words for it in sen]
            # segword_content = ' '.join(words)
            segword_content = dochtmlcon
            segword_title = ' '.join(fool.cut(doctitle)[0])
        except:
            segword_content = ''
            segword_title = ''
        segword_title = ' '.join([it for it in segword_title.split() if it.isalpha() and it in vocab][:title_len])
        segword_content = ' '.join([it for it in segword_content.split() if it.isalpha() and it in vocab][:2000])
        segword_content = segword_content.replace(' 中 选 ', ' 中选 ').replace(' 中 标 ', ' 中标 ').replace(' 补 遗 ', ' 补遗 '). \
            replace(' 更 多', '').replace(' 更多', '').replace(' 中 号 ', ' 中标 ').replace(' 中 选人 ', ' 中选人 '). \
            replace(' 点击 下载 查看', '').replace(' 咨询 报价 请 点击', '').replace('终结', '终止')
        doc_word_list = segword_content.split()
        if len(doc_word_list) > sequen_len / 2:
            doc_sens = get_kw_senten(' '.join(doc_word_list[100:500]))
            doc_sens = ' '.join(doc_word_list[:100]) + '\n' + '\n'.join(doc_sens)
        else:
            doc_sens = ' '.join(doc_word_list[:sequen_len])
        datas.append(word2id(doc_sens.split(), max_len=sequen_len))
        datas_title.append(word2id(segword_title.split(), max_len=title_len))
        return datas, datas_title
    def predict(self, title, content):
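        """Classify one document: run the type model first; if the document is
        not news (class index 4, 新闻资讯), refine with the life-cycle model.
        Returns (predicted label, softmax probability)."""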
        # print('preparing prediction')
        data_content, data_title = self.predict_process(docid='', doctitle=title, dochtmlcon=content)
        pred = self.type_sess.run(self.type_softmax,
                                  feed_dict={self.type_title: [[embedding_matrix[i] for i in l] for l in data_title],
                                             self.type_content: [[embedding_matrix[i] for i in l] for l in data_content],
                                             self.type_mask: 1 - np.not_equal(data_content, 0),
                                             self.type_mask_title: 1 - np.not_equal(data_title, 0),
                                             self.type_prob: 1}
                                  )
        id = np.argmax(pred, axis=1)[0]
        prob = pred[0][id]
        if id != 4:
            pred = self.lift_sess.run(self.lift_softmax,
                                      feed_dict={self.lift_title: [[embedding_matrix[i] for i in l] for l in data_title],
                                                 self.lift_content: [[embedding_matrix[i] for i in l] for l in data_content],
                                                 self.mask: 1 - np.not_equal(data_content, 0),
                                                 self.mask_title: 1 - np.not_equal(data_title, 0),
                                                 self.lift_prob: 1}
                                      )
            id = np.argmax(pred, axis=1)[0]
            prob = pred[0][id]
            return self.id2life[id], prob
        else:
            return self.id2type[id], prob
def save_pb():
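    """Freeze the trained checkpoint into model/channel.pb, keeping only the
    input placeholders and the softmax output node."""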
    from tensorflow import graph_util
    saver = tf.train.import_meta_graph('model/channel_foolcut_title_lstm_content_att_concat_withoutEmb0621_adam.ckpt.meta')
    graph = tf.get_default_graph()
    graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, 'model/channel_foolcut_title_lstm_content_att_concat_withoutEmb0621_adam.ckpt')  # 0608
        output_graph_def = graph_util.convert_variables_to_constants(sess,
                                                                     input_graph_def=graph_def,
                                                                     output_node_names=['inputs/inputs',
                                                                                        'inputs/dropout',
                                                                                        'inputs/title',
                                                                                        'inputs/mask',
                                                                                        'inputs/mask_title',
                                                                                        # 'output/logit',
                                                                                        'output/softmax'])
        # 'inputs/labels',
        # 'net/alphas'])
        with tf.gfile.GFile('model/channel.pb', 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph" % len(output_graph_def.node))
def predict_pb():
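    """Sanity-check the frozen model/channel.pb graph: load it, run the logit
    node over data/df_test.xlsx in batches and print the first few predictions."""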
    batch_size = 512
    # lb_path = 'data/id2label.pkl'
    # if os.path.exists(lb_path):
    #     with open(lb_path, 'rb') as f:
    #         id2label = pickle.load(f)
    #     label2id = {v: k for k, v in id2label.items()}
    lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
    id2label = {k: v for k, v in enumerate(lb)}
    label2id = {v: k for k, v in id2label.items()}
    print(label2id)
    df_test = pd.read_excel('data/df_test.xlsx')  # df_test_all.xlsx
    df_test = df_test[df_test.loc[:, 'label'] != '招标文件']
    df_test.dropna(subset=['segword'], inplace=True)
    df_test.reset_index(drop=True, inplace=True)
    df_test.fillna('', inplace=True)
    if 'relabel' in df_test.columns:
        df_test['relabel'] = df_test['relabel'].apply(lambda x: '新闻资讯' if x == '其他公告' else x)
        df_test['label'] = df_test.apply(lambda x: x['relabel'] if x['relabel'] not in ['', 1] else x['label'], axis=1)
    df_test['label'] = df_test['label'].apply(lambda x: '新闻资讯' if x == '其他公告' else x)
    print('Labels updated')
    # assert set(df_test['label']) == set(label2id)
    # data_test, label_test = data_process(df_test, label2id=label2id)
    data_test, label_test, title_test, df_test = data_process(df_test, label2id=label2id)
    batch_size = 128
    predicts = []
    alphas = []
    alpha_t = []
    max_prob = []
    import gc
    with tf.Graph().as_default() as graph:
        output_graph_def = graph.as_graph_def()
        with open('model/channel.pb', 'rb') as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name='')
            print("%d ops in the final graph" % len(output_graph_def.node))
            del output_graph_def
            print('Memory reclaimed:', gc.collect())
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
            prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
            title = sess.graph.get_tensor_by_name('inputs/title:0')
            logit = sess.graph.get_tensor_by_name('output/logit:0')
            # labels = sess.graph.get_tensor_by_name('inputs/labels:0')
            # softmax_output = sess.graph.get_tensor_by_name('output/softmax:0')
            # alpha = sess.graph.get_tensor_by_name('net/alphas:0')
            print('data_test.shape:', data_test.shape)
            print(logit)
            print(title)
            # for i in range(int((len(df_test) - 1) / batch_size) + 1):
            #     logit_, alpha_, softmax_output_ = sess.run([logit, alpha, softmax_output],  # ,alpha_title
            #                                                feed_dict={
            #                                                    inputs: data_test[i * batch_size:(i + 1) * batch_size],
            #                                                    title: title_test[i * batch_size:(i + 1) * batch_size],
            #                                                    labels: label_test[i * batch_size:(i + 1) * batch_size],
            #                                                    prob: 1})
            for i in range(int((len(df_test) - 1) / batch_size) + 1):
                # print("%d ops in the final graph" % len(output_graph_def.node))
                logit_ = sess.run(logit,  # ,alpha_title
                                  feed_dict={
                                      inputs: data_test[i * batch_size:(i + 1) * batch_size],
                                      title: title_test[i * batch_size:(i + 1) * batch_size],
                                      prob: 1})
                predicts.extend(logit_)  # logit_[0]
                # alphas.extend(alpha_)
                # max_prob.extend(np.max(softmax_output_, axis=-1))
                # alpha_t.extend(alpha_title_)
            # assert len(predicts) == len(df_test)
            # assert len(alphas) == len(df_test)
            pred_new = [id2label[id] for id in predicts]
            df_test['pred_new'] = pd.Series(pred_new)
            print(pred_new[:10])
if __name__ == "__main__":
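    # The main block doubles as an experiment log: the commented sections cover
    # data sampling/merging, word segmentation, training and pb-based prediction
    # runs; only the frozen-graph export (save_pb) is active below.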
    # import glob
    # for num in [12, 13, 14, 15, 16]:
    #     df = pd.DataFrame()
    #     df_l = []
    #     for file in glob.glob('data/docchannel带数据源2021-04-{}_bidi_process_predict*'.format(num)):
    #         df_tmp = pd.read_excel(file)
    #         df_l.append(df_tmp)
    #     df = df.append(df_l, ignore_index=True)
    #     # df = pd.read_excel('G:/公告docchannel分类数据/docchannel带数据源2021-04-12_bidi_process.xlsx')
    #     df.drop_duplicates(subset=['segword'], inplace=True)
    #     print(len(df))
    #
    #     l = []
    #     for sour in set(df['web_source_no']):
    #         df_sour = df[df.loc[:, 'web_source_no'] == sour]
    #         for lb in set(df_sour['label']):
    #             df_lb = df_sour[df_sour.loc[:, 'label'] == lb]
    #             if len(df_lb) > 5:
    #                 l.append(df_lb.sample(5))
    #             else:
    #                 l.append(df_lb)
    #     df_2 = pd.DataFrame()
    #     df_2 = df_2.append(l, ignore_index=True)
    #     print('过滤后数量:', len(df_2))
    #     df_2.reset_index(drop=True, inplace=True)
    #     df_2.to_excel('data/docchannel带数据源2021-04-{}_bidi_process_predict_filter.xlsx'.format(num))

    # import glob
    # df = pd.DataFrame()
    # df_l = []
    # for num in [12, 13, 14, 15, 16]:
    #     for file in glob.glob('data/docchannel带数据源2021-04-{}_bidi_process_predict_filter*'.format(num)):
    #         df_tmp = pd.read_excel(file)
    #         df_l.append(df_tmp)
    # df = df.append(df_l, ignore_index=True)
    # df.drop_duplicates(subset=['segword'], inplace=True)
    # df.sort_values(by=['web_source_no', 'label'], inplace=True)
    # df.reset_index(drop=True, inplace=True)
    # num = int(len(df) / 4) + 2
    # for i in range(4):
    #     df_t = df[i * num:(i + 1) * num]
    #     df_t.to_excel('data/docchannel带数据源2021-04-12-16抽取数据_{}.xlsx'.format(i))

    # cut_words()
    # import datetime
    # import os
    # in_date = '2021-04-11'  # '2018-01-05'
    # dt = datetime.datetime.strptime(in_date, "%Y-%m-%d")
    # cut_words('2021-04-23_全国_数据导出1')
    # for i in range(2, 6, 1):  # 100, 800, 9
    #     date = (dt + datetime.timedelta(days=i)).strftime('%Y-%m-%d')
    #     filename = 'docchannel带数据源{}'.format(date)
    #     print(filename)
    #     if os.path.exists('data/' + filename + '.xlsx'):
    #         print('准备分词')
    #         cut_words(filename)
    print('About to enter train')
    # train()
    # train_withoutEmb()
    # predict_withoutEmb()
    print('Training finished')
    # predict()
    # cut_words('公告类型标注数据2021-05-26')
    save_pb()
    # lb = ['采招数据', '土地矿产', '拍卖出让', '产权交易', '新闻资讯']
    # id2label = {k: v for k, v in enumerate(lb)}
    # label2id = {v: k for k, v in id2label.items()}
    # lb = ['采购意向', '招标预告', '招标公告', '招标答疑', '公告变更', '资审结果', '中标信息', '合同公告', '废标公告']
    # id2label = {k: v for k, v in enumerate(lb)}
    # label2id = {v: k for k, v in id2label.items()}
    # import numpy as np
    # DocChannel = DocChannel()
    # print(DocChannel.lift_softmax)
    #
    # # df_test = pd.read_excel('data/df_test.xlsx')
    # df_test = pd.read_excel('data/df_test_公告类型.xlsx')
    # i = 6
    # for i in range(len(df_test)):
    #     title = df_test.loc[i, 'doctitle']
    #     # content = df_test.loc[i, 'dochtmlcon']
    #     content = df_test.loc[i, 'segword']
    #     pred, prob = DocChannel.predict(title, content)
    #     print('预测类别:%s, 阈值:%.4f, 标注类别:%s'
    #           % (pred, prob, df_test.loc[i, 'label']))
    #     lb_id = np.argmax(pred, axis=1)
    #     print(pred)
    #     print('预测类别:%s, 阈值:%.4f, 标注类别:%s'
    #           % (id2label.get(lb_id[0], 'unknow'), pred[0][lb_id[0]], df_test.loc[i, 'label']))
    # print('预测完毕!')
    # rs = np.argmax(pred, axis=-1)
    # print(pred)
    # print(rs)
    # for i, p in zip(rs, pred):
    #     print(p[i])

    # import gc
    # del vocab
    # del embedding_matrix
    # print('清理内存 ', gc.collect())
    # predict_pb()

    # lb_path = 'data/id2label.pkl'
    # if os.path.exists(lb_path):
    #     with open(lb_path, 'rb') as f:
    #         id2label = pickle.load(f)
    #     label2id = {v: k for k, v in id2label.items()}
    # df_test = pd.read_excel('data/df_test_predict.xlsx')
    # data_test, label_test, title_test, df_test = data_process(df_test, label2id=label2id)
    # df_test.to_excel('data/df_test_predict.xlsx')

    # from collections import Counter
    # df_train = pd.read_excel('data/df_train.xlsx')
    # df_test = pd.read_excel('data/df_test_predict.xlsx')
    # c1 = Counter(df_train['label'])
    # c3 = Counter(df_test['pred_new'])
    # c2 = Counter(df_test['label'])
    # print(c1)
    # print(c2)
    # print(c3)
    # print(set(c1) - set(c2))
    # print(set(c2) - set(c1))

    # split_words = []
    # df = pd.read_excel(
    #     '/data/esa_sdk/text_classifier_2020_09_28/channel_data/docchannel带数据源0413_filter_bidi_process.xlsx')
    # for text in df['segword']:
    #     w2 = re.findall(' (\w \w) ', text)
    #     w3 = re.findall(' (\w \w \w) ', text)
    #     if w2:
    #         split_words.append(w2)
    #     if w3:
    #         split_words.append(w3)
    # from collections import Counter
    # c = Counter([w for l in split_words for w in l])
    # m = c.most_common()
    # print(m[20:100])
    # print()