IterateModeling_LR.py

#coding:utf8
import psycopg2
from keras import models
from keras import layers
from keras import optimizers,losses,metrics
from keras.callbacks import ModelCheckpoint
import codecs
import copy
from BiddingKG.dl.common.Utils import *
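
# Workflow of this script: labelled money-entity mentions are read from PostgreSQL,
# each mention is encoded as embedded left/right token windows plus a feature vector
# of the money string (partMoney), and a bidirectional-LSTM classifier assigns one of
# three labels. trainingIteration_category() additionally cleans noisy labels by
# fold-wise self-training, and the import* helpers load the results back into the DB.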

#sourcetable = "label_guest_money"
sourcetable = "hand_label_money"
domain = sourcetable.split("_")[2]
model_file = "model_"+domain+".model"

entity_type = "money"
input_shape = (2,10,128)    # (left/right context spans, window size in tokens, word-embedding dimension)
input2_shape = [7]          # dimension of the feature vector returned by partMoney() for the entity text
output_shape = [3]          # number of label classes


def getTokensLabels(t,isTrain=True,predict=False):
    '''
    @param:
        t: table that holds the labelled data
        isTrain: whether to return the training split
        predict: whether this call is for prediction
    @return: the processed inputs and labels of the labelled data
    '''
    conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    if predict:
        sql = '''
        select A.tokens,B.begin_index,B.end_index,0,B.entity_text,B.entity_id from sentences A,entity_mention B where B.entity_type in ('money') and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index
        and B.doc_id='360e1a62-7f82-11e8-ae30-ecf4bbc56acd' order by B.doc_id
        '''
    else:
        select_sql = " select A.tokens,B.begin_index,B.end_index,C.label,B.entity_text,C.entity_id "
        '''
        if isTrain:
            train_sql = " and C.id not in(select variable_id from dd_graph_variables_holdout) "
        else:
            train_sql = " and C.id in(select variable_id from dd_graph_variables_holdout)"
        '''
        # hold out 1000 articles (lowest ids) as the evaluation split
        if isTrain:
            train_sql = " and A.doc_id not in(select id from articles_processed order by id limit 1000) "
        else:
            train_sql = " and A.doc_id in(select id from articles_processed order by id limit 1000)"
        sql = select_sql+" from sentences A,entity_mention_copy B,"+t+" C where B.entity_type='"+entity_type+"' and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id "+train_sql
    cursor.execute(sql)
    print(sql)
    data_x = []
    data_x1 = []
    data_y = []
    data_context = []
    rows = cursor.fetchmany(1000)
    allLimit = 250000
    all = 0
    while(rows):
        for row in rows:
            if all>=allLimit:
                break
            # left/right token windows around the entity, embedded to shape (2,10,128)
            item_x = embedding(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=input_shape[1]),shape=input_shape)
            # hand-crafted features of the money string itself
            item_x1 = partMoney(row[4])
            item_y = np.zeros(output_shape)
            item_y[row[3]] = 1
            all += 1
            if not isTrain:
                item_context = []
                item_context.append(row[5])
                data_context.append(item_context)
            data_x.append(item_x)
            data_x1.append(item_x1)
            data_y.append(item_y)
        rows = cursor.fetchmany(1000)
    # transpose to (2, N, window, embedding) so data_x[0]/data_x[1] are the left/right model inputs
    return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_x1),np.array(data_y),data_context


def getBiRNNModel():
    '''
    @summary: build the model
    '''
    L_input = layers.Input(shape=input_shape[1:],dtype="float32")
    C_input = layers.Input(shape=([input2_shape[0]]),dtype="float32")
    R_input = layers.Input(shape=input_shape[1:],dtype="float32")
    #lstm_0 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(ThreeBilstm(0)(input))
    lstm_0 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(L_input)
    avg_0 = layers.GlobalAveragePooling1D()(lstm_0)
    C_embed = layers.Dense(4,activation="sigmoid")(C_input)
    #lstm_1 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(C_input)
    #avg_1 = layers.GlobalAveragePooling1D()(lstm_1)
    lstm_2 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(R_input)
    avg_2 = layers.GlobalAveragePooling1D()(lstm_2)
    #concat = layers.merge([avg_0,avg_1,avg_2],mode="concat")
    # layers.merge(mode="concat") is the Keras 1 API; layers.concatenate is its Keras 2 equivalent
    concat = layers.concatenate([avg_0,C_embed,avg_2])
    output = layers.Dense(output_shape[0],activation="softmax")(concat)
    model = models.Model(inputs=[L_input,C_input,R_input],outputs=output)
    model.compile(optimizer=optimizers.RMSprop(lr=0.001),loss=losses.binary_crossentropy,metrics=[precision,recall,f1_score])
    return model


def training():
    '''
    @summary: train the model
    '''
    model = getBiRNNModel()
    model.summary()
    train_x,train_x1,train_y,_ = getTokensLabels(isTrain=True,t=sourcetable)
    print(np.shape(train_x))
    test_x,test_x1,test_y,test_context = getTokensLabels(isTrain=False,t=sourcetable)
    checkpoint = ModelCheckpoint(model_file+".hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode='min')
    history_model = model.fit(x=[train_x[0],train_x1,train_x[1]],y=train_y,validation_data=([test_x[0],test_x1,test_x[1]],test_y),epochs=100,batch_size=300,shuffle=True,callbacks=[checkpoint])
    predict_y = model.predict([test_x[0],test_x1,test_x[1]])
    # write the predicted label of each test entity: entity_id<TAB>label
    with codecs.open("predict.txt","w",encoding="utf8") as f:
        for i in range(len(predict_y)):
            f.write(str(test_context[i][0]))
            f.write("\t")
            f.write(str(np.argmax(predict_y[i])))
            f.write("\n")
        f.flush()
    model.save(model_file)
    #print_metrics(history_model)
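

# Iterative relabelling: the labelled data is split into `parts` folds. In every pass,
# a fresh model is trained with one fold held out, re-predicts that fold, and any
# prediction made with high confidence (>=0.8, >=0.9 for the final tail) overwrites
# the stored label. Repeating this for `iterate` passes gradually cleans noisy labels.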
def trainingIteration_category(iterate=2,label_table="label_guest_money"):
    '''
    @summary: iteratively train the model and revise the labels; intended for data whose label accuracy is low
    @param:
        iterate: number of iterations
        label_table: table that holds the labels
    '''
    def getDatasets():
        conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
        cursor = conn.cursor()
        select_sql = " select A.tokens,B.begin_index,B.end_index,C.label,B.entity_text,C.entity_id "
        sql = select_sql+" from sentences A,entity_mention B,"+label_table+" C where B.entity_type='"+entity_type+"' and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id order by A.doc_id "
        cursor.execute(sql)
        print(sql)
        data_x = []
        data_x1 = []
        data_y = []
        id_set = []
        rows = cursor.fetchmany(1000)
        allLimit = 250000
        all = 0
        while(rows):
            for row in rows:
                if all>=allLimit:
                    break
                item_x = embedding(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2]))
                item_x1 = partMoney(row[4])
                item_y = np.zeros(output_shape)
                item_y[row[3]] = 1
                all += 1
                data_x.append(item_x)
                data_x1.append(item_x1)
                data_y.append(item_y)
                id_set.append(row[5])
            rows = cursor.fetchmany(1000)
        return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_x1),np.array(data_y),id_set

    train_x,train_x1,train_y,id_set = getDatasets()
    alllength = len(train_x[0])
    parts = 8
    num_parts = alllength//parts
    copy_y = copy.copy(train_y)
    for ite in range(iterate):
        # relabel the first parts-1 folds, one held-out fold at a time
        for j in range(parts-1):
            print("iterate:",str(ite)+"/"+str(iterate-1),str(j)+"/"+str(parts-1))
            model = getBiRNNModel()
            model.summary()
            test_begin = j*num_parts
            test_end = (j+1)*num_parts
            checkpoint = ModelCheckpoint(model_file+".hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode='min')
            history_model = model.fit(x=[np.concatenate((train_x[0][0:test_begin],train_x[0][test_end:])),np.concatenate((train_x1[0:test_begin],train_x1[test_end:])),np.concatenate((train_x[1][0:test_begin],train_x[1][test_end:]))],y=np.concatenate((copy_y[0:test_begin],copy_y[test_end:])),validation_data=([train_x[0][test_begin:test_end],train_x1[test_begin:test_end],train_x[1][test_begin:test_end]],copy_y[test_begin:test_end]),epochs=30,batch_size=300,shuffle=True,callbacks=[checkpoint])
            model.load_weights(model_file+".hdf5")
            predict_y = model.predict([train_x[0][test_begin:test_end],train_x1[test_begin:test_end],train_x[1][test_begin:test_end]])
            # overwrite the held-out fold's labels with confident predictions
            for i in range(len(predict_y)):
                if np.max(predict_y[i])>=0.8:
                    max_index = np.argmax(predict_y[i])
                    for h in range(len(predict_y[i])):
                        if h==max_index:
                            copy_y[i+test_begin][h] = 1
                        else:
                            copy_y[i+test_begin][h] = 0
        # then relabel the remaining tail (everything from test_begin onwards) with a stricter 0.9 threshold
        print("iterate:",str(ite)+"/"+str(iterate-1),str(j)+"/"+str(parts-1))
        model = getBiRNNModel()
        model.summary()
        test_begin = j*num_parts
        checkpoint = ModelCheckpoint(model_file+".hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode='min')
        history_model = model.fit(x=[train_x[0][0:test_begin],train_x1[0:test_begin],train_x[1][0:test_begin]],y=copy_y[0:test_begin],validation_data=([train_x[0][test_begin:],train_x1[test_begin:],train_x[1][test_begin:]],copy_y[test_begin:]),epochs=30,batch_size=300,shuffle=True,callbacks=[checkpoint])
        model.load_weights(model_file+".hdf5")
        predict_y = model.predict([train_x[0][test_begin:],train_x1[test_begin:],train_x[1][test_begin:]])
        for i in range(len(predict_y)):
            if np.max(predict_y[i])>=0.9:
                max_index = np.argmax(predict_y[i])
                for h in range(len(predict_y[i])):
                    if h==max_index:
                        copy_y[i+test_begin][h] = 1
                    else:
                        copy_y[i+test_begin][h] = 0
    # write the final labels to a file
    with codecs.open("final_label_"+domain+".txt","w",encoding="utf8") as f:
        for i in range(len(id_set)):
            f.write(id_set[i])
            f.write("\t")
            f.write(str(np.argmax(copy_y[i])))
            f.write("\n")
        f.flush()


def predict():
    '''
    @summary: predict on the test data
    '''
    test_x,test_x1,_,ids = getTokensLabels(sourcetable,isTrain=False,predict=True)
    model = getBiRNNModel()
    model.load_weights(model_file+".hdf5")
    predict_y = model.predict([test_x[0],test_x1,test_x[1]])
    # write entity_id, predicted label, and the raw class probabilities
    with codecs.open("test_predict_"+domain+".txt","w",encoding="utf8") as f:
        for i in range(len(predict_y)):
            f.write(ids[i][0])
            f.write("\t")
            f.write(str(np.argmax(predict_y[i])))
            f.write("\t")
            value = ""
            for item in predict_y[i]:
                value += str(item)+","
            f.write(value[:-1])
            f.write("\n")
        f.flush()
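

# File formats written above and read back by the import helpers below:
#   final_label_<domain>.txt   : "entity_id<TAB>label" per line (from trainingIteration_category)
#   test_predict_<domain>.txt  : "entity_id<TAB>label<TAB>comma-separated probabilities" (from predict)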


def importIterateLabel():
    '''
    @summary: import the labels produced by the iteration
    '''
    file = "final_label_"+domain+".txt"
    conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    tablename = file.split(".")[0]
    # create the table if it does not exist, otherwise empty it
    cursor.execute(" SELECT to_regclass('"+tablename+"') is null ")
    flag = cursor.fetchall()[0][0]
    if flag:
        cursor.execute(" create table "+tablename+"(entity_id text,label int)")
    else:
        cursor.execute(" delete from "+tablename)
    with codecs.open(file,"r",encoding="utf8") as f:
        while(True):
            line = f.readline()
            if not line:
                break
            line_split = line.split("\t")
            entity_id=line_split[0]
            label = line_split[1]
            sql = " insert into "+tablename+"(entity_id,label) values('"+str(entity_id)+"',"+str(label)+")"
            cursor.execute(sql)
    conn.commit()
    conn.close()


def importtestPredict():
    '''
    @summary: import the predictions for the test data
    '''
    file = "test_predict_"+domain+".txt"
    conn = psycopg2.connect(dbname="BiddingKG",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    tablename = file.split(".")[0]
    # create the table if it does not exist, otherwise empty it
    cursor.execute(" SELECT to_regclass('"+tablename+"') is null ")
    flag = cursor.fetchall()[0][0]
    if flag:
        cursor.execute(" create table "+tablename+"(entity_id text,label int,value text)")
    else:
        cursor.execute(" delete from "+tablename)
    with codecs.open(file,"r",encoding="utf8") as f:
        while(True):
            line = f.readline()
            if not line:
                break
            line_split = line.split("\t")
            entity_id=line_split[0]
            predict = line_split[1]
            value = line_split[2]
            sql = " insert into "+tablename+"(entity_id,label,value) values('"+str(entity_id)+"',"+str(predict)+",'"+str(value)+"')"
            cursor.execute(sql)
    conn.commit()
    conn.close()


def autoIterate():
    #trainingIteration_binary()
    trainingIteration_category()
    importIterateLabel()
    training()
    predict()


if __name__ == "__main__":
    training()
    #trainingIteration_category()
    #importIterateLabel()
    #predict()
    #importtestPredict()
    #autoIterate()