IterateModeling_LR.py
#coding:utf8
import sys
import os
import glob
sys.path.append(os.path.abspath("../.."))
import psycopg2
from keras import models
from keras import layers
from keras import optimizers,losses,metrics
from keras.callbacks import ModelCheckpoint
import codecs
import copy
import re
import numpy as np
from BiddingKG.dl.common.Utils import *
import pandas as pd

sourcetable = "label_guest_role"
domain = sourcetable.split("_")[2]   # -> "role"
model_file = "model_"+domain+".model"

input_shape = (2,10,128)
output_shape = [6]
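# Reading of the shape constants above (inferred from how they are used below, not from
# any external documentation): input_shape = (2, 10, 128) means two context windows
# (left and right of the entity span), 10 tokens per window, 128-dim token embeddings.
# The dataset helpers build arrays of shape (samples, 2, 10, 128) and transpose them to
# (2, samples, 10, 128) so that x[0] / x[1] can be fed to the model's two inputs.
# output_shape = [6] is a one-hot vector over the six role labels (0-5).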
def getTokensLabels(t,isTrain=True,predict=False):
    '''
    @param:
        t: table that holds the labelled data
        isTrain: whether this is the training split
        predict: whether this is for prediction (validation)
    @return: the processed inputs and labels built from the labelled data
    '''
    conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    if predict:
        sql = '''
        select A.tokens,B.begin_index,B.end_index,0,B.entity_id from sentences A,entity_mention_copy B where B.entity_type in ('org','company') and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index
        and B.doc_id in (select doc_id from articles_validation ) order by B.doc_id
        '''
    else:
        select_sql = " select A.tokens,B.begin_index,B.end_index,C.label,C.entity_id "
        '''
        if isTrain:
            train_sql = " and C.id not in(select variable_id from dd_graph_variables_holdout) "
        else:
            train_sql = " and C.id in(select variable_id from dd_graph_variables_holdout)"
        '''
        '''
        if isTrain:
            train_sql = " and A.doc_id not in(select id from articles_processed order by id limit 1000) "
        else:
            train_sql = " and A.doc_id in(select id from articles_processed order by id limit 1000)"
        '''
        if isTrain:
            train_sql = " and C.entity_id not in(select entity_id from is_wintenderer_label_inference where id in(select variable_id from dd_graph_variables_holdout))"
        else:
            #train_sql = " and C.entity_id in(select entity_id from is_wintenderer_label_inference where id in(select variable_id from dd_graph_variables_holdout))"
            train_sql = " and exists(select 1 from test_predict_money h,entity_mention g where h.entity_id=g.entity_id and A.doc_id=g.doc_id) order by B.doc_id limit 2000 "
        sql = select_sql+" from sentences A,entity_mention_copy B,"+t+" C where B.entity_type in ('org','company') and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id "+train_sql
    print(sql)
    cursor.execute(sql)
    data_x = []
    data_y = []
    data_context = []
    rows = cursor.fetchmany(1000)
    allLimit = 330000
    all = 0
    while(rows):
        for row in rows:
            if all>=allLimit:
                break
            item_x = embedding(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=input_shape[1]),shape=input_shape)
            item_y = np.zeros(output_shape)
            item_y[row[3]] = 1
            all += 1
            if not isTrain:
                item_context = []
                item_context.append(row[4])
                data_context.append(item_context)
            data_x.append(item_x)
            data_y.append(item_y)
        rows = cursor.fetchmany(1000)
    return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_y),data_context
def getBiRNNModel():
    '''
    @summary: build the model (two BiLSTM branches over the left and right context windows)
    '''
    L_input = layers.Input(shape=input_shape[1:],dtype="float32")
    #C_input = layers.Input(shape=(10,128),dtype="float32")
    R_input = layers.Input(shape=input_shape[1:],dtype="float32")
    #lstm_0 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(ThreeBilstm(0)(input))
    lstm_0 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(L_input)
    avg_0 = layers.GlobalAveragePooling1D()(lstm_0)
    #lstm_1 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(C_input)
    #avg_1 = layers.GlobalAveragePooling1D()(lstm_1)
    lstm_2 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(R_input)
    avg_2 = layers.GlobalAveragePooling1D()(lstm_2)
    #concat = layers.merge([avg_0,avg_1,avg_2],mode="concat")
    concat = layers.concatenate([avg_0,avg_2])
    output = layers.Dense(output_shape[0],activation="softmax")(concat)
    model = models.Model(inputs=[L_input,R_input],outputs=output)
    model.compile(optimizer=optimizers.Adam(lr=0.001),loss=losses.binary_crossentropy,metrics=[precision,recall,f1_score])
    return model
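# Minimal usage sketch for the model above (illustrative only; the dummy arrays stand in
# for real left/right context embeddings):
#   model = getBiRNNModel()
#   left = np.zeros((1, input_shape[1], input_shape[2]), dtype="float32")
#   right = np.zeros((1, input_shape[1], input_shape[2]), dtype="float32")
#   probs = model.predict([left, right])   # -> shape (1, 6), softmax over the role labels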
def loadTrainData(percent=0.9):
    files = ["id_token_text_begin_end_label.pk","id_token_text_begin_end_label.pk1"]
    data_x = []
    data_y = []
    #data_id = []
    test_x = []
    test_y = []
    test_id = []
    for file in files:
        data = load(file)
        for row in data:
            item_x = embedding(spanWindow(tokens=row[1],begin_index=row[3],end_index=row[4],size=input_shape[1]),shape=input_shape)
            item_y = np.zeros(output_shape)
            label = int(row[5])
            if label not in [0,1,2,3,4,5]:
                continue
            item_y[label] = 1
            if np.random.random()<percent:
                data_x.append(item_x)
                data_y.append(item_y)
                #data_id.append(row[0])
            else:
                test_x.append(item_x)
                test_y.append(item_y)
                test_id.append(row[0])
    return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_y),np.transpose(np.array(test_x),(1,0,2,3)),np.array(test_y),None,test_id
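# Note: loadTrainData() returns six values so callers can unpack
# (train_x, train_y, test_x, test_y, _, test_id); the fifth element is just a placeholder.
# The train/test split is random per sample (probability `percent` of going to the
# training side), so the exact ratio varies slightly between runs.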
def training():
    '''
    @summary: train the model
    '''
    model = getBiRNNModel()
    model.summary()
    #train_x,train_y,_ = getTokensLabels(isTrain=True,t="hand_label_role")
    train_x,train_y,test_x,test_y,_,test_id = loadTrainData()
    save([test_x,test_y,test_id],"val_data.pk")
    checkpoint = ModelCheckpoint(
        "../../dl_dev/role/log/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1{val_f1_score:.3f}.h5", monitor="val_loss", verbose=1, save_best_only=True, mode='min')
    print(np.shape(train_x))
    history_model = model.fit(x=[train_x[0],train_x[1]],y=train_y,validation_data=([test_x[0],test_x[1]],test_y),epochs=120,batch_size=512,shuffle=True,callbacks=[checkpoint])
    predict_y = model.predict([test_x[0],test_x[1]])
    model.save(model_file)
    #print_metrics(history_model)
def val():
    '''
    @summary: run labelled entities through the BiLSTM and CNN role models and export
              the samples where the two models disagree to an Excel file
    '''
    files = []
    for file in glob.glob("C:\\Users\\User\\Desktop\\20190416要素\\*.html"):
        filename = file.split("\\")[-1]
        files.append(filename)
    conn = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    sql = '''
    select A.entity_id,A.entity_text,A.begin_index,A.end_index,A.label,A.values,B.tokens,A.doc_id
    from entity_mention A,sentences B
    where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index
    and A.entity_type in ('org','company')
    and A.label!='None'
    and not exists(select 1 from turn_label where entity_id=A.entity_id)
    order by A.label
    '''
    cursor.execute(sql)
    rows = cursor.fetchall()
    list_entity_id = []
    list_before = []
    list_after = []
    list_text = []
    list_label = []
    list_prob = []
    repeat = set()
    data_x = []
    cnn_x = []
    for row in rows:
        entity_id = row[0]
        entity_text = row[1]
        begin_index = row[2]
        end_index = row[3]
        label = int(row[4])
        values = row[5][1:-1].split(",")
        tokens = row[6]
        doc_id = row[7]
        if doc_id not in files:
            continue
        if float(values[label])<0.5:
            continue
        beforeafter = spanWindow(tokens, begin_index, end_index, 10,center_include=True,text=entity_text)
        # skip duplicated (before, text, after) contexts
        if ("".join(beforeafter[0]),"".join(beforeafter[1]),"".join(beforeafter[2])) in repeat:
            continue
        repeat.add(("".join(beforeafter[0]),"".join(beforeafter[1]),"".join(beforeafter[2])))
        item_x = embedding(spanWindow(tokens=tokens,begin_index=begin_index,end_index=end_index,size=input_shape[1]),shape=input_shape)
        data_x.append(item_x)
        cnn_x.append(encodeInput(spanWindow(tokens=tokens,begin_index=begin_index,end_index=end_index,size=10,center_include=True,word_flag=True,text=entity_text), word_len=50, word_flag=True))
        list_entity_id.append(entity_id)
        list_before.append("".join(beforeafter[0]))
        list_after.append("".join(beforeafter[2]))
        list_text.append("".join(beforeafter[1]))
        list_label.append(label)
        list_prob.append(values[label])
    # predict with the BiLSTM model
    model = models.load_model("../../dl_dev/role/log/new_biLSTM-ep012-loss0.028-val_loss0.040-f10.954.h5", custom_objects={"precision":precision, "recall":recall, "f1_score":f1_score})
    data_x = np.transpose(np.array(data_x),(1,0,2,3))
    predict_value = model.predict([data_x[0],data_x[1]])
    predict_y = np.argmax(predict_value,1)
    list_newprob = []
    for label,value in zip(predict_y,predict_value):
        list_newprob.append(value[label])
    print("len",len(list_entity_id))
    # predict with the CNN model
    model_cnn = models.load_model("../../dl_dev/role/log/ep071-loss0.107-val_loss0.122-f10.956.h5", custom_objects={"precision":precision, "recall":recall, "f1_score":f1_score})
    cnn_x = np.transpose(np.array(cnn_x),(1,0,2))
    predict_value = model_cnn.predict([cnn_x[0],cnn_x[1],cnn_x[2]])
    predict_y_cnn = np.argmax(predict_value,1)
    list_newprob_cnn = []
    for label,value in zip(predict_y_cnn,predict_value):
        list_newprob_cnn.append(value[label])
    print("len",len(list_entity_id))
    # keep only the entities where the two models disagree
    data = []
    for id,before,text,after,label_bi,prob,label_bi1,newprob,label_cnn,newprob_cnn in zip(list_entity_id,list_before,list_text,list_after,list_label,list_prob,predict_y,list_newprob,predict_y_cnn,list_newprob_cnn):
        if label_bi1!=label_cnn:
            data.append([id,before,text,after,label_bi,prob,label_bi1,newprob,label_cnn,newprob_cnn])
    data.sort(key=lambda x:x[6])
    list_entity_id = []
    list_before = []
    list_after = []
    list_text = []
    list_label = []
    list_prob = []
    list_newlabel = []
    list_newprob = []
    list_newlabel_cnn = []
    list_newprob_cnn = []
    for item in data:
        list_entity_id.append(item[0])
        list_before.append(item[1])
        list_text.append(item[2])
        list_after.append(item[3])
        list_label.append(item[4])
        list_prob.append(item[5])
        list_newlabel.append(item[6])
        list_newprob.append(item[7])
        list_newlabel_cnn.append(item[8])
        list_newprob_cnn.append(item[9])
    parts = 1
    parts_num = len(list_entity_id)//parts
    for i in range(parts-1):
        data = {"entity_id":list_entity_id[i*parts_num:(i+1)*parts_num],"list_before":list_before[i*parts_num:(i+1)*parts_num],"list_after":list_after[i*parts_num:(i+1)*parts_num],"list_text":list_text[i*parts_num:(i+1)*parts_num],"list_label":list_label[i*parts_num:(i+1)*parts_num],"list_prob":list_prob[i*parts_num:(i+1)*parts_num]}
        df = pd.DataFrame(data)
        df.to_excel("未标注错误_"+str(i)+".xls",columns=["entity_id","list_before","list_text","list_after","list_label","list_prob"])
    i = parts - 1
    data = {"entity_id":list_entity_id[i*parts_num:],"list_before":list_before[i*parts_num:],"list_after":list_after[i*parts_num:],"list_text":list_text[i*parts_num:],"list_label":list_label[i*parts_num:],"list_prob":list_prob[i*parts_num:],"list_newlabel":list_newlabel[i*parts_num:],"list_newprob":list_newprob[i*parts_num:],"list_newlabel_cnn":list_newlabel_cnn[i*parts_num:],"list_newprob_cnn":list_newprob_cnn[i*parts_num:]}
    df = pd.DataFrame(data)
    df.to_excel("测试数据_role-cnnw-biw"+str(i)+".xls",columns=["entity_id","list_before","list_text","list_after","list_label","list_prob","list_newlabel","list_newprob","list_newlabel_cnn","list_newprob_cnn"])
def validation():
    '''
    @summary: run the saved validation set through a trained checkpoint and export the
              predictions to Excel, marking whether each prediction matches its label
    '''
    conn1 = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
    conn2 = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
    cursor1 = conn1.cursor()
    cursor2 = conn2.cursor()
    model = getBiRNNModel()
    model.load_weights("log/ep010-loss0.033-val_loss0.043-f10.950.h5")
    [test_x,test_y,test_id] = load("val_data.pk")
    predict_y = model.predict([test_x[0],test_x[1]])
    list_id = []
    list_before = []
    list_text = []
    list_after = []
    list_same = []
    list_predict = []
    list_label = []
    data = []
    for id,predict,label in zip(test_id,np.argmax(predict_y,1),np.argmax(test_y,1)):
        # same is 0 when the prediction matches the label and 1 when it does not
        if predict==label:
            same = 0
            text = ""
            beforeafter = [[],[]]
        else:
            same = 1
            if re.search("比地",id) is not None:
                sql = " select A.tokens,B.entity_text,B.begin_index,B.end_index from sentences A,entity_mention B where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id='"+id+"' "
                cursor1.execute(sql)
                rows = cursor1.fetchall()
            else:
                sql = " select A.tokens,B.entity_text,B.begin_index,B.end_index from sentences A,entity_mention_copy B where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id='"+id+"' "
                cursor2.execute(sql)
                rows = cursor2.fetchall()
            retu = rows[0]
            text = retu[1]
            beforeafter = spanWindow(retu[0], retu[2], retu[3], 10)
        data.append([id,same,"".join(beforeafter[0]),text,"".join(beforeafter[1]),label,predict])
    data.sort(key=lambda x:x[1])
    for item in data:
        list_id.append(item[0])
        list_same.append(item[1])
        list_before.append(item[2])
        list_text.append(item[3])
        list_after.append(item[4])
        list_label.append(item[5])
        list_predict.append(item[6])
    df = pd.DataFrame({"list_id":list_id,"list_same":list_same,"list_before":list_before,"list_text":list_text,"list_after":list_after,"list_label":list_label,"list_predict":list_predict})
    columns = ["list_id","list_same","list_before","list_text","list_after","list_label","list_predict"]
    df.to_excel("result.xls",index=False,columns=columns)
    conn1.close()
    conn2.close()
def trainingIteration_category(iterate=2,label_table=sourcetable):
    '''
    @summary: iteratively train the model and revise the labels; intended for the case
              where the accuracy of the labelled data is not high
    @param:
        iterate: number of iterations
        label_table: table that holds the label data
    '''
    def getDatasets():
        conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
        cursor = conn.cursor()
        select_sql = " select A.tokens,B.begin_index,B.end_index,C.label,C.entity_id "
        sql = select_sql+" from sentences A,entity_mention B,"+label_table+" C where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id order by A.doc_id "
        cursor.execute(sql)
        print(sql)
        data_x = []
        data_y = []
        id_set = []
        rows = cursor.fetchmany(1000)
        allLimit = 320000
        all = 0
        while(rows):
            for row in rows:
                if all>=allLimit:
                    break
                item_x = embedding(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2]))
                item_y = np.zeros(output_shape)
                item_y[row[3]] = 1
                all += 1
                data_x.append(item_x)
                data_y.append(item_y)
                id_set.append(row[4])
            rows = cursor.fetchmany(1000)
        return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_y),id_set
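    # Self-training scheme implemented below: the data is split into `parts` contiguous
    # segments; for each segment a fresh model is trained on the remaining segments and
    # then predicts on the held-out one. Wherever the top predicted probability is >= 0.8,
    # that sample's label in copy_y is overwritten with the predicted class. After all
    # iterations the revised labels are written to final_label_<domain>.txt.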
    train_x,train_y,id_set = getDatasets()
    alllength = len(train_x[0])
    parts = 6
    num_parts = alllength//parts
    copy_y = copy.copy(train_y)
    for ite in range(iterate):
        for j in range(parts-1):
            print("iterate:",str(ite)+"/"+str(iterate-1),str(j)+"/"+str(parts-1))
            model = getBiRNNModel()
            model.summary()
            test_begin = j*num_parts
            test_end = (j+1)*num_parts
            checkpoint = ModelCheckpoint(model_file+".hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode='min')
            history_model = model.fit(x=[np.concatenate((train_x[0][0:test_begin],train_x[0][test_end:])),np.concatenate((train_x[1][0:test_begin],train_x[1][test_end:]))],y=np.concatenate((copy_y[0:test_begin],copy_y[test_end:])),validation_data=([train_x[0][test_begin:test_end],train_x[1][test_begin:test_end]],copy_y[test_begin:test_end]),epochs=30,batch_size=300,shuffle=True,callbacks=[checkpoint])
            model.load_weights(model_file+".hdf5")
            predict_y = model.predict([train_x[0][test_begin:test_end],train_x[1][test_begin:test_end]])
            # overwrite the held-out segment's labels with confident predictions
            for i in range(len(predict_y)):
                if np.max(predict_y[i])>=0.8:
                    max_index = np.argmax(predict_y[i])
                    for h in range(len(predict_y[i])):
                        if h==max_index:
                            copy_y[i+test_begin][h] = 1
                        else:
                            copy_y[i+test_begin][h] = 0
        # handle the remaining tail of the data: train on the earlier part and relabel
        # everything from test_begin to the end
        print("iterate:",str(ite)+"/"+str(iterate-1),str(j)+"/"+str(parts-1))
        model = getBiRNNModel()
        model.summary()
        test_begin = j*num_parts
        checkpoint = ModelCheckpoint(model_file+".hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode="min")
        history_model = model.fit(x=[train_x[0][0:test_begin],train_x[1][0:test_begin]],y=copy_y[0:test_begin],validation_data=([train_x[0][test_begin:],train_x[1][test_begin:]],copy_y[test_begin:]),epochs=30,batch_size=300,shuffle=True,callbacks=[checkpoint])
        model.load_weights(model_file+".hdf5")
        predict_y = model.predict([train_x[0][test_begin:],train_x[1][test_begin:]])
        for i in range(len(predict_y)):
            if np.max(predict_y[i])>=0.8:
                max_index = np.argmax(predict_y[i])
                for h in range(len(predict_y[i])):
                    if h==max_index:
                        copy_y[i+test_begin][h] = 1
                    else:
                        copy_y[i+test_begin][h] = 0
    # write the revised labels out
    with codecs.open("final_label_"+domain+".txt","w",encoding="utf8") as f:
        for i in range(len(id_set)):
            f.write(id_set[i])
            f.write("\t")
            f.write(str(np.argmax(copy_y[i])))
            f.write("\n")
        f.flush()
def predict():
    '''
    @summary: predict on the test data
    '''
    test_x,_,ids = getTokensLabels("final_label_role", isTrain=False,predict=True)
    model = models.load_model(model_file,custom_objects={'precision':precision,'recall':recall,'f1_score':f1_score})
    predict_y = model.predict([test_x[0],test_x[1]])
    with codecs.open("test_predict_"+domain+".txt","w",encoding="utf8") as f:
        for i in range(len(predict_y)):
            f.write(ids[i][0])
            f.write("\t")
            f.write(str(np.argmax(predict_y[i])))
            f.write("\t")
            value = ""
            for item in predict_y[i]:
                value += str(item)+","
            f.write(value[:-1])
            f.write("\n")
        f.flush()
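# Each line of test_predict_<domain>.txt written above has the form
#   entity_id <TAB> predicted_label <TAB> p0,p1,...,p5
# importtestPredict() below parses exactly these three tab-separated fields and loads
# them into a database table.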
def importIterateLabel():
    '''
    @summary: import the revised labels produced by the iteration
    '''
    file = "final_label_"+domain+".txt"
    conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    tablename = file.split(".")[0]
    # create the table if it does not exist yet, otherwise empty it
    cursor.execute(" SELECT to_regclass('"+tablename+"') is null ")
    flag = cursor.fetchall()[0][0]
    if flag:
        cursor.execute(" create table "+tablename+"(entity_id text,label int)")
    else:
        cursor.execute(" delete from "+tablename)
    with codecs.open(file,"r",encoding="utf8") as f:
        while(True):
            line = f.readline()
            if not line:
                break
            line_split = line.split("\t")
            entity_id=line_split[0]
            label = line_split[1]
            sql = " insert into "+tablename+"(entity_id,label) values('"+str(entity_id)+"',"+str(label)+")"
            cursor.execute(sql)
    conn.commit()
    conn.close()
def importtestPredict():
    '''
    @summary: import the predictions made on the test data
    '''
    file = "test_predict_"+domain+".txt"
    conn = psycopg2.connect(dbname="BiddingKG",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    tablename = file.split(".")[0]
    # create the table if it does not exist yet, otherwise empty it
    cursor.execute(" SELECT to_regclass('"+tablename+"') is null ")
    flag = cursor.fetchall()[0][0]
    if flag:
        cursor.execute(" create table "+tablename+"(entity_id text,label int,value text)")
    else:
        cursor.execute(" delete from "+tablename)
    with codecs.open(file,"r",encoding="utf8") as f:
        while(True):
            line = f.readline()
            if not line:
                break
            line_split = line.split("\t")
            entity_id=line_split[0]
            predict = line_split[1]
            value = line_split[2]
            sql = " insert into "+tablename+"(entity_id,label,value) values('"+str(entity_id)+"',"+str(predict)+",'"+str(value)+"')"
            cursor.execute(sql)
    conn.commit()
    conn.close()
def autoIterate():
    '''
    @summary: full pipeline: iteratively relabel, import the revised labels, retrain on
              them, then predict on the test data
    '''
    #trainingIteration_binary()
    trainingIteration_category()
    importIterateLabel()
    training()
    predict()
def test1(entity_id):
    conn = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    # note: `predict` here refers to the module-level function, so this condition is always true
    if predict:
        sql = "select A.tokens,B.begin_index,B.end_index,0,B.entity_id from sentences A,entity_mention B where B.entity_type in ('org','company') and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id='"+entity_id+"'"
    print(sql)
    cursor.execute(sql)
    data_x = []
    data_y = []
    rows = cursor.fetchmany(1000)
    while(rows):
        for row in rows:
            item_x = encodeInput(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=10,center_include=True,word_flag=True), word_len=50, word_flag=True)
            item_y = np.zeros(output_shape)
            item_y[row[3]] = 1
            data_x.append(item_x)
            data_y.append(item_y)
        rows = cursor.fetchmany(1000)
    model = models.load_model("../../dl_dev/role/log/ep017-loss0.088-val_loss0.125-f10.955.h5", custom_objects={'precision':precision, 'recall':recall, 'f1_score':f1_score})
    test_x = np.transpose(np.array(data_x),(1,0,2))
    predict_y = model.predict([test_x[0],test_x[1],test_x[2]])
    print(predict_y)
if __name__=="__main__":
    #training()
    val()
    #validation()
    #test()
    #trainingIteration_category()
    #importIterateLabel()
    #predict()
    #importtestPredict()
    #autoIterate()
    #test1("比地_101_61333318.html_0_116_122")