# train.py

#from general_data import getTokensLabels
import sys
import os
sys.path.append(os.path.abspath("../.."))
# from model import *
from keras.callbacks import ModelCheckpoint
from keras import layers,models,optimizers,losses
import psycopg2
from BiddingKG.dl.common.Utils import *
import pandas as pd
from BiddingKG.dl.common.models import *
# re/numpy/tensorflow are used below; imported explicitly here in case the
# wildcard imports above do not already expose them
import re
import numpy as np
import tensorflow as tf
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = ""

sourcetable = "label_guest_person"
domain = sourcetable.split("_")[2]
model_file = "model_"+domain+".model"

input_shape = (2,10,128)    # (left/right windows, window size in tokens, embedding dim)
output_shape = [5]          # number of label classes
def getTokensLabels(t,isTrain=True):
    '''
    @summary: fetch the model's input/output data
    @param:
        t: table that holds the label data
    @return: type:array,array,list  meaning: inputs, outputs, entity ids
    '''
    conn = psycopg2.connect(dbname="BiddingKG",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    if isTrain:
        # training set: all labelled entities except the first 2000
        sql = " select B.tokens,A.begin_index,A.end_index,C.label,A.entity_id from train_entity_copy A,train_sentences_copy B,"+t+" C where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and A.entity_type='person' and A.entity_id=C.entity_id and C.entity_id not in (select entity_id from "+t+" order by entity_id limit 2000)"
    else:
        # validation set: the first 2000 labelled entities
        sql = " select B.tokens,A.begin_index,A.end_index,C.label,A.entity_id from train_entity_copy A,train_sentences_copy B,"+t+" C where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and A.entity_type='person' and A.entity_id=C.entity_id and C.entity_id in (select entity_id from "+t+" order by entity_id limit 2000)"
    cursor.execute(sql)
    print(sql)
    data_x = []
    data_y = []
    data_context = []
    rows = cursor.fetchmany(1000)
    allLimit = 250000
    all = 0
    i = 0
    while(rows):
        for row in rows:
            if all>=allLimit:
                break
            item_x = embedding(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=input_shape[1]),shape=input_shape)
            # item_x = encodeInput(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=10), word_len=50, word_flag=True,userFool=False)
            # _span = spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=10,word_flag=False)
            # item_x = encodeInput(_span, word_len=10, word_flag=False,userFool=False)
            item_y = np.zeros(output_shape)
            item_y[row[3]] = 1    # one-hot encode the label
            all += 1
            if not isTrain:
                item_context = []
                item_context.append(row[4])
                data_context.append(item_context)
            data_x.append(item_x)
            data_y.append(item_y)
            i += 1
        rows = cursor.fetchmany(1000)
    return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_y),data_context
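# Layout note (assuming embedding()/spanWindow() from BiddingKG.dl.common
# return one (2,10,128) item per entity, i.e. a left and a right window of
# 10 token embeddings): the final transpose turns the stacked
# (n_samples,2,10,128) array into (2,n_samples,10,128), so callers can feed
# the two windows straight into the two-input model, e.g.
#     x, y, _ = getTokensLabels(t="hand_label_person", isTrain=True)
#     model.fit(x=[x[0], x[1]], y=y, ...)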
def getBiRNNModel():
    '''
    @summary: build the model
    '''
    L_input = layers.Input(shape=input_shape[1:],dtype="float32")
    #C_input = layers.Input(shape=(10,128),dtype="float32")
    R_input = layers.Input(shape=input_shape[1:],dtype="float32")
    #lstm_0 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(ThreeBilstm(0)(input))
    lstm_0 = layers.Bidirectional(layers.LSTM(32,return_sequences=True))(L_input)
    avg_0 = layers.GlobalAveragePooling1D()(lstm_0)
    #lstm_1 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(C_input)
    #avg_1 = layers.GlobalAveragePooling1D()(lstm_1)
    lstm_2 = layers.Bidirectional(layers.LSTM(32,return_sequences=True))(R_input)
    avg_2 = layers.GlobalAveragePooling1D()(lstm_2)
    #concat = layers.merge([avg_0,avg_1,avg_2],mode="concat")
    # layers.merge is the Keras 1 API; on Keras 2 this would be layers.concatenate([avg_0,avg_2])
    concat = layers.merge([avg_0,avg_2],mode="concat")
    output = layers.Dense(output_shape[0],activation="softmax")(concat)
    model = models.Model(inputs=[L_input,R_input],outputs=output)
    model.compile(optimizer=optimizers.Adam(lr=0.0005),loss=losses.binary_crossentropy,metrics=[precision,recall,f1_score])
    return model
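# Architecture note: the left and right context windows of a person entity
# go through separate bidirectional LSTMs, each is reduced with global
# average pooling, the two pooled vectors are concatenated and classified
# with a 5-way softmax; precision/recall/f1_score are the custom metrics
# pulled in by the wildcard imports above. A minimal usage sketch
# (assuming the label tables read by getTokensLabels are available):
#     m = getBiRNNModel()
#     x, y, _ = getTokensLabels(t="hand_label_person", isTrain=True)
#     m.fit(x=[x[0], x[1]], y=y, batch_size=256, epochs=1)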
def training():
    '''
    @summary: train the model
    '''
    model = getBiRNNModel()
    model.summary()
    train_x,train_y,_ = getTokensLabels(isTrain=True,t="hand_label_person")
    #print(np.shape(train_x))
    test_x,test_y,test_context = getTokensLabels(isTrain=False,t="hand_label_person")
    checkpoint = ModelCheckpoint(model_file+".hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode='min')
    history_model = model.fit(x=[train_x[0],train_x[1]],y=train_y,validation_data=([test_x[0],test_x[1]],test_y),epochs=100,batch_size=256,shuffle=True,callbacks=[checkpoint])
    # predict_y = model.predict([test_x[0],test_x[1]])
    #
    # conn = psycopg2.connect(dbname='BiddingKG', user='postgres',password='postgres',host='192.168.2.101')
    # cursor = conn.cursor()
    # table = 'predict_person'
    # cursor.execute(" select to_regclass('"+table+"') is null ")
    # notExists = cursor.fetchall()[0][0]
    # if notExists:
    #     cursor.execute(" create table "+table+" (entity_id text,predect int,label int)")
    # else:
    #     cursor.execute(" delete from "+table)
    #
    #
    #
    # with open("predict.txt","w",encoding="utf8") as f:
    #     for i in range(len(predict_y)):
    #         if np.argmax(predict_y[i]) != np.argmax(test_y[i]):
    #             f.write("\n")
    #             f.write(str(test_context[i][0]))
    #             f.write("\t")
    #             f.write(str(np.argmax(predict_y[i])))
    #             f.write("\t")
    #             f.write(str(np.argmax(test_y[i])))
    #             f.write("\n")
    #         sql = " insert into "+table+"(entity_id ,predect ,label) values('"+str(test_context[i][0])+"','"+str(int(np.argmax(predict_y[i])))+"','"+str(int(np.argmax(test_y[i])))+"')"
    #         # print(sql)
    #         cursor.execute(sql)
    #     conn.commit()
    #     cursor.close()
    #     conn.close()
    #     f.flush()
    #     f.close()
    #print_metrics(history_model)
def train():
    # alternative training entry point using the shared BiLSTM model from BiddingKG.dl.common.models
    train_x,train_y,_ = getTokensLabels(isTrain=True,t="hand_label_person")
    test_x,test_y,test_context = getTokensLabels(isTrain=False,t="hand_label_person")
    with tf.Session() as sess:
        vocab,matrix = getVocabAndMatrix(getModel_w2v(),Embedding_size=128)
        model = getBiLSTMModel(input_shape=(2,10,128), vocab=vocab, embedding_weights=matrix, classes=4)
        callback = ModelCheckpoint(filepath="log/"+"ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1_score{val_f1_score:.3f}.h5",monitor="val_loss",save_best_only=True, save_weights_only=True, mode="min")
        model.fit(x=[train_x[0],train_x[1]],y=train_y,batch_size=128,epochs=600,callbacks=[callback],validation_data=[[test_x[0],test_x[1]],test_y])
def predict():
    '''
    @summary: run prediction on the data
    '''
    def getTokensLabels():
        conn = psycopg2.connect(dbname="BidiPro",user="postgres",password="postgres",host="192.168.2.101")
        cursor = conn.cursor()
        #sql = '''
        #SELECT s.tokens,e.begin_index,e.end_index,e.doc_id,e.entity_id,e.sentence_index,e.entity_text,e.entity_type from entity_mention e,sentences s
        #WHERE s.doc_id=e.doc_id AND s.sentence_index=e.sentence_index AND e.entity_id not in (SELECT entity_id from entity_label) and entity_type in ('person') limit 10000
        #'''
        sql = '''
        SELECT s.tokens,e.begin_index,e.end_index,e.doc_id,e.entity_id,e.sentence_index,e.entity_text,e.entity_type from entity_mention e,sentences s
        WHERE s.doc_id=e.doc_id AND s.sentence_index=e.sentence_index AND e.doc_id in(SELECT doc_id from articles_validation) and entity_type in ('person')
        '''
        cursor.execute(sql)
        print(sql)
        data_x = []
        doc_id = []
        ent_id = []
        sen = []
        ent_text = []
        dianhua = []
        rows = cursor.fetchmany(1000)
        key_word = re.compile('电话[:|:]\d{7,12}|联系方式[:|:]\d{7,12}')
        phone = re.compile('1[3|4|5|7|8][0-9][-|——|—]?\d{4}[-|——|—]?\d{4}|\d{3,4}[-|——|—]\d{7,8}/\d{3,8}|\d{3,4}[-|——|—]\d{7,8}转\d{1,4}|\d{3,4}[-|——|—]\d{7,8}|[\(|\(]0\d{2,3}[\)|\)]\d{7,8}') # contact phone number
        while(rows):
            for row in rows:
                item_x = embedding(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2]))
                s = spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=15)
                s2 = ''.join(s[1])
                s2 = re.sub(r',\)', '-', s2)    # ')' escaped so the pattern compiles
                s2 = re.sub('\s','',s2)
                have_key = re.findall(key_word, s2)
                have_phone = re.findall(phone, s2)
                if have_phone:
                    dianhua.append(have_phone)
                elif have_key:
                    dianhua.append(have_phone)    # note: have_phone is empty in this branch
                else:
                    dianhua.append('')
                sen.append(s2)
                ent_id.append(row[4])
                ent_text.append(row[6])
                data_x.append(item_x)
                doc_id.append(row[3])
            rows = cursor.fetchmany(1000)
        cursor.close()
        conn.close()
        return np.transpose(np.array(data_x),(1,0,2,3)),doc_id,ent_id,sen,ent_text,dianhua
    test_x,doc_id,ent_id,sen,ent_text,dianhua = getTokensLabels()
    model = models.load_model("model_person.model",custom_objects={'precision':precision,'recall':recall,'f1_score':f1_score})
    predict_y = model.predict([test_x[0],test_x[1]])
    label = [np.argmax(y) for y in predict_y]
    data = {'doc_id':doc_id, 'ent_id':ent_id, 'sen':sen, 'entity_text':ent_text, 'dianhua':dianhua, 'label':label}
    df = pd.DataFrame(data)
    df.to_excel('data/person_phone.xls')
    conn = psycopg2.connect(dbname='BidiPro', user='postgres',password='postgres',host='192.168.2.101')
    cursor = conn.cursor()
    table = 'person_phone_predict'
    cursor.execute(" select to_regclass('"+table+"') is null ")
    notExists = cursor.fetchall()[0][0]
    if notExists:
        cursor.execute(" create table "+table+" (doc_id text,entity_id text,entity text,label int,predict text,phone text)")
    else:
        cursor.execute(" delete from "+table)
    for i in range(len(df['ent_id'])):
        pre_y = [str(a) for a in predict_y[i]]
        sql = " insert into "+table+"(doc_id,entity_id,entity,label,predict,phone) values('"+str(df['doc_id'][i])+"','"+str(df['ent_id'][i])+"','"+str(df['entity_text'][i])+"',"+str(int(df['label'][i]))+",'"+str(','.join(pre_y))+"','"+str(','.join(df['dianhua'][i]))+"')"
        #print(sql)
        cursor.execute(sql)
    conn.commit()
    print('commit completed')
    cursor.close()
    conn.close()
if __name__ == '__main__':
    #get_data()
    #label_data()
    #post_data()
    training()
    predict()
    # train()