train_2.py

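"""
Training and evaluation script for a time-entity classification model.

The functions below pull annotated time entities from an iepy PostgreSQL corpus,
build left/right context windows around each entity, train a twin-BiLSTM classifier
(getModel()), and run error analysis on the predictions.
"""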
import sys
import os
sys.path.append(os.path.abspath("../.."))
import pandas as pd
import re
import numpy as np
import psycopg2
from keras.callbacks import ModelCheckpoint
from keras import layers, models, optimizers, losses
from BiddingKG.dl.common.Utils import *
from BiddingKG.dl.common.models import *
from sklearn.metrics import classification_report
from sklearn.utils import shuffle, class_weight
import matplotlib.pyplot as plt
input_shape = (2, 30, 60)
input_shape2 = (2, 10, 128)
output_shape = [4]
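# Shape conventions (as used below): each sample is a pair of context windows
# (left, right). input_shape appears to be character-level contexts (2 sides x 30
# chars, presumably 60-dim vectors, fed to embedding_word); input_shape2 token-level
# contexts (2 sides x 10 tokens x 128-dim vectors, fed to embedding and matching the
# model input); output_shape is the number of classification labels.

# get_data(): fetch time annotations and tokenization data from the iepy PostgreSQL
# corpus and cache the resulting DataFrame locally via save('db_time_data.pk').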
def get_data():
    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30_prc.csv", index_col=0)
    id_set = set()
    for id in data_load['document_id']:
        id_set.add(id)
    conn = psycopg2.connect(dbname="iepy", user="postgres", password="postgres", host="192.168.2.101")
    sql = "SELECT A.human_identifier,A.sentences,A.tokens,A.offsets_to_text,B.value " \
          "FROM corpus_iedocument A,brat_bratannotation B " \
          "WHERE A.human_identifier = '%s' " \
          "AND A.human_identifier = B.document_id "
    db_data = []
    count = 0
    for id in list(id_set):
        count += 1
        print(count)
        cur1 = conn.cursor()
        cur1.execute(sql % (id))
        db_data.extend(cur1.fetchall())
        cur1.close()
    conn.close()
    columns = ['document_id', 'sentences', 'tokens', 'offsets_to_text', 'value']
    df = pd.DataFrame(db_data, columns=columns)
    df = df[df['value'].str.contains('time')]
    df = df.reset_index(drop=True)
    print(len(df))
    time_label = df['value'].str.split(expand=True)
    time_label.columns = ['_', 'label_type', 'begin_index', 'end_index', 'entity_text']
    time_label = time_label.drop('_', axis=1)
    df = pd.concat([df, time_label], axis=1)
    print(df.info())
    df['tokens'] = [token[2:-2].split("', '") for token in df['tokens']]
    df['sentences'] = [sentence[1:-1].split(", ") for sentence in df['sentences']]
    df['sentences'] = [[int(s) for s in sentence] for sentence in df['sentences']]
    df['offsets_to_text'] = [offset[1:-1].split(", ") for offset in df['offsets_to_text']]
    df['offsets_to_text'] = [[int(o) for o in offset] for offset in df['offsets_to_text']]
    save(df, 'db_time_data.pk')
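# getModel(): twin BiLSTM over the left and right context windows, global average
# pooling on each side, concatenation, then a softmax over output_shape[0] classes.
# precision/recall/f1_score are custom metrics provided by the wildcard imports above.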
def getModel():
    '''
    @summary: Time classification model
    '''
    L_input = layers.Input(shape=input_shape2[1:], dtype='float32')
    R_input = layers.Input(shape=input_shape2[1:], dtype='float32')
    L_lstm = layers.Bidirectional(layers.LSTM(40, return_sequences=True, dropout=0.1))(L_input)
    # L_lstm = layers.LSTM(32,return_sequences=True,dropout=0.2)(L_input)
    avg_l = layers.GlobalAveragePooling1D()(L_lstm)
    R_lstm = layers.Bidirectional(layers.LSTM(40, return_sequences=True, dropout=0.1))(R_input)
    # R_lstm = layers.LSTM(32, return_sequences=True, dropout=0.2)(R_input)
    avg_r = layers.GlobalAveragePooling1D()(R_lstm)
    concat = layers.concatenate([avg_l, avg_r])
    # lstm = layers.LSTM(24,return_sequences=False,dropout=0.2)(concat)
    output = layers.Dense(output_shape[0], activation="softmax")(concat)
    model = models.Model(inputs=[L_input, R_input], outputs=output)
    learn_rate = 0.0005
    model.compile(optimizer=optimizers.Adam(lr=learn_rate),
                  loss=losses.binary_crossentropy,
                  metrics=[precision, recall, f1_score])
    model.summary()
    return model
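# training(): character-level pipeline on newdata_30_prc.csv. Contexts are converted
# to character lists, embedded with embedding_word(shape=input_shape), and split 80/20
# with a fixed random_state; the best checkpoint (lowest val_loss) is then reloaded
# and evaluated with classification_report.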
def training():
    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30_prc.csv", index_col=0)
    data_load = data_load.reset_index(drop=True)
    test_data = data_load.sample(frac=0.2, random_state=8)
    train_data = data_load.drop(test_data.index, axis=0)
    train_data = train_data.reset_index(drop=True)
    train_x = []
    train_y = []
    for left, right, label in zip(train_data['context_left'], train_data['context_right'], train_data['re_label']):
        y = np.zeros(output_shape)
        y[label] = 1
        left = str(left)
        right = str(right)
        if left == 'nan': left = ''
        if right == 'nan': right = ''
        left = list(left)
        right = list(right)
        context = [left, right]
        x = embedding_word(context, shape=input_shape)
        train_x.append(x)
        train_y.append(y)
    test_x = []
    test_y = []
    for left, right, label in zip(test_data['context_left'], test_data['context_right'], test_data['re_label']):
        y = np.zeros(output_shape)
        y[label] = 1
        left = str(left)
        right = str(right)
        if left == 'nan': left = ''
        if right == 'nan': right = ''
        left = list(left)
        right = list(right)
        context = [left, right]
        x = embedding_word(context, shape=input_shape)
        test_x.append(x)
        test_y.append(y)
    train_y, test_y = (np.array(train_y), np.array(test_y))
    train_x, test_x = (np.array(train_x), np.array(test_x))
    train_x, test_x = (np.transpose(train_x, (1, 0, 2, 3)), np.transpose(test_x, (1, 0, 2, 3)))
    model = getModel()
    epochs = 150
    batch_size = 256
    checkpoint = ModelCheckpoint("model_label_time_classify.model.hdf5", monitor="val_loss", verbose=1,
                                 save_best_only=True, mode='min')
    # cw = class_weight.compute_class_weight('auto',np.unique(np.argmax(train_y,axis=1)),np.argmax(train_y,axis=1))
    # cw = dict(enumerate(cw))
    history = model.fit(
        x=[train_x[0], train_x[1]],
        y=train_y,
        validation_data=([test_x[0], test_x[1]], test_y),
        epochs=epochs,
        batch_size=batch_size,
        shuffle=True,
        callbacks=[checkpoint],
        class_weight='auto'
    )
    # plot_loss(history=history)
    load_model = models.load_model("model_label_time_classify.model.hdf5",
                                   custom_objects={'precision': precision, 'recall': recall, 'f1_score': f1_score})
    y_pre = load_model.predict([test_x[0], test_x[1]])
    # y_pre = load_model.predict(test_x[0])
    # Per-class prediction evaluation
    res1 = classification_report(np.argmax(test_y, axis=1), np.argmax(y_pre, axis=1))
    print(res1)
    y_pre2 = load_model.predict([train_x[0], train_x[1]])
    # y_pre2 = load_model.predict(train_x[0])
    res2 = classification_report(np.argmax(train_y, axis=1), np.argmax(y_pre2, axis=1))
    print(res2)
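# train2(): token-level pipeline on tokens_data.csv. context_left/context_right are
# stored as stringified token lists, so they are parsed back into lists and embedded
# with embedding(shape=input_shape2); training and evaluation otherwise mirror training().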
def train2():
    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\tokens_data.csv", index_col=0)
    data_load = data_load.reset_index(drop=True)
    data_load['context_left'] = [left[2:-2].split("', '") for left in data_load['context_left']]
    data_load['context_right'] = [right[2:-2].split("', '") for right in data_load['context_right']]
    test_data = data_load.sample(frac=0.2, random_state=8)
    train_data = data_load.drop(test_data.index, axis=0)
    train_data = train_data.reset_index(drop=True)
    train_x = []
    train_y = []
    for left, right, label in zip(train_data['context_left'], train_data['context_right'], train_data['label']):
        y = np.zeros(output_shape)
        y[label] = 1
        context = [left, right]
        x = embedding(context, shape=input_shape2)
        train_x.append(x)
        train_y.append(y)
    test_x = []
    test_y = []
    for left, right, label in zip(test_data['context_left'], test_data['context_right'], test_data['label']):
        y = np.zeros(output_shape)
        y[label] = 1
        context = [left, right]
        x = embedding(context, shape=input_shape2)
        test_x.append(x)
        test_y.append(y)
    train_y, test_y = (np.array(train_y), np.array(test_y))
    train_x, test_x = (np.array(train_x), np.array(test_x))
    train_x, test_x = (np.transpose(train_x, (1, 0, 2, 3)), np.transpose(test_x, (1, 0, 2, 3)))
    model = getModel()
    epochs = 150
    batch_size = 256
    checkpoint = ModelCheckpoint("model_label_time_classify.model.hdf5", monitor="val_loss", verbose=1,
                                 save_best_only=True, mode='min')
    # cw = class_weight.compute_class_weight('auto',np.unique(np.argmax(train_y,axis=1)),np.argmax(train_y,axis=1))
    # cw = dict(enumerate(cw))
    history = model.fit(
        x=[train_x[0], train_x[1]],
        y=train_y,
        validation_data=([test_x[0], test_x[1]], test_y),
        epochs=epochs,
        batch_size=batch_size,
        shuffle=True,
        callbacks=[checkpoint],
        class_weight='auto'
    )
    # plot_loss(history=history)
    load_model = models.load_model("model_label_time_classify.model.hdf5",
                                   custom_objects={'precision': precision, 'recall': recall, 'f1_score': f1_score})
    y_pre = load_model.predict([test_x[0], test_x[1]])
    # y_pre = load_model.predict(test_x[0])
    # Per-class prediction evaluation
    res1 = classification_report(np.argmax(test_y, axis=1), np.argmax(y_pre, axis=1))
    print(res1)
    y_pre2 = load_model.predict([train_x[0], train_x[1]])
    # y_pre2 = load_model.predict(train_x[0])
    res2 = classification_report(np.argmax(train_y, axis=1), np.argmax(y_pre2, axis=1))
    print(res2)
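# predict2(): run the saved model over all of tokens_data.csv and write the
# misclassified rows to error4-30.csv for manual inspection.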
def predict2():
    model1 = models.load_model("model_label_time_classify.model.hdf5",
                               custom_objects={'precision': precision, 'recall': recall, 'f1_score': f1_score})
    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\tokens_data.csv", index_col=0)
    data_load['context_left'] = [left[2:-2].split("', '") for left in data_load['context_left']]
    data_load['context_right'] = [right[2:-2].split("', '") for right in data_load['context_right']]
    test_x = []
    test_y = []
    for left, right, label in zip(data_load['context_left'], data_load['context_right'], data_load['label']):
        y = np.zeros(output_shape)
        y[label] = 1
        context = [left, right]
        x = embedding(context, shape=input_shape2)
        test_x.append(x)
        test_y.append(y)
    test_x = np.transpose(np.array(test_x), (1, 0, 2, 3))
    pre_y = model1.predict([test_x[0], test_x[1]])
    data_load['pre'] = [np.argmax(item) for item in pre_y]
    error_data = data_load[data_load['label'] != data_load['pre']]
    # print(error_data.info())
    error_data.to_csv("C:\\Users\\admin\\Desktop\\error4-30.csv")
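# predict(): same error analysis as predict2(), but on the character-level data
# (newdata_30_prc.csv, embedding_word, 're_label' column).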
def predict():
    model1 = models.load_model("model_label_time_classify.model.hdf5",
                               custom_objects={'precision': precision, 'recall': recall, 'f1_score': f1_score})
    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30_prc.csv", index_col=0)
    test_x = []
    test_y = []
    for left, right, label in zip(data_load['context_left'], data_load['context_right'], data_load['re_label']):
        y = np.zeros(output_shape)
        y[label] = 1
        left = str(left)
        right = str(right)
        if left == 'nan': left = ''
        if right == 'nan': right = ''
        left = list(left)
        right = list(right)
        context = [left, right]
        x = embedding_word(context, shape=input_shape)
        test_x.append(x)
        test_y.append(y)
    test_x = np.transpose(np.array(test_x), (1, 0, 2, 3))
    pre_y = model1.predict([test_x[0], test_x[1]])
    data_load['pre'] = [np.argmax(item) for item in pre_y]
    error_data = data_load[data_load['re_label'] != data_load['pre']]
    # print(error_data.info())
    error_data.to_csv("C:\\Users\\admin\\Desktop\\error4-30.csv")
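# data_process(): trim each context to its own sentence: keep only the text after the
# last "。" in the left context and up to the first "。" in the right context.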
def data_process():
    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30.csv", index_col=0)
    re_left = re.compile("。[^。]*?$")
    re_right = re.compile("^[^。]*?。")
    left_list = []
    right_list = []
    for left, right in zip(data_load['context_left'], data_load['context_right']):
        left = str(left)
        right = str(right)
        if right == 'nan':
            right = ''
            # print(1)
        if re.search("。", left):
            left = re_left.search(left)
            left = left.group()[1:]
        if re.search("。", right):
            right = re_right.search(right)
            right = right.group()
        left_list.append(left)
        right_list.append(right)
    data_load['context_left'] = left_list
    data_load['context_right'] = right_list
    data_load.to_csv("C:\\Users\\admin\\Desktop\\newdata_30_prc.csv")
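# data_process2(): further truncate contexts to at most 20 characters per side.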
def data_process2():
    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30_prc.csv", index_col=0)
    left_list = []
    right_list = []
    for left, right in zip(data_load['context_left'], data_load['context_right']):
        left = str(left)
        right = str(right)
        if right == 'nan':
            right = ''
        if left == 'nan':
            left = ''
        left = left[max(len(left) - 20, 0):]
        right = right[:20]
        left_list.append(left)
        right_list.append(right)
    data_load['context_left'] = left_list
    data_load['context_right'] = right_list
    data_load.to_csv("C:\\Users\\admin\\Desktop\\newdata_20_prc.csv")
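# data_process3(): map each annotation's character offsets onto token indices, take a
# 10-token window on each side with spanWindow(), clip the windows at sentence
# boundaries ("。"), and write the token-level dataset to tokens_data.csv.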
def data_process3():
    data = load('db_time_data.pk')
    data = data.drop('value', axis=1)
    token_begin = []
    token_end = []
    context_left = []
    context_right = []
    data2 = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30_prc2.csv")
    label = []
    # data=data[:20]
    for id, sentences, tokens, offset, begin, end, entity_text in zip(data['document_id'], data['sentences'], data['tokens'], data['offsets_to_text'],
                                                                      data['begin_index'], data['end_index'], data['entity_text']):
        _label = data2[(data2['document_id'] == int(id)) & (data2['begin_index'] == int(begin))][:1]
        if not _label.empty:
            _label = int(_label['re_label'])
        else:
            _label = 0
        label.append(_label)
        begin = int(begin)
        end = int(end)
        entity_tbegin = 0
        entity_tend = 0
        find_begin = False
        for t in range(len(offset)):
            if not find_begin:
                if offset[t] == begin:
                    entity_tbegin = t
                    find_begin = True
                if offset[t] > begin:
                    entity_tbegin = t - 1
                    find_begin = True
            if offset[t] >= end:
                entity_tend = t
                break
        token_begin.append(entity_tbegin)
        token_end.append(entity_tend)
        s = spanWindow(tokens=tokens, begin_index=entity_tbegin, end_index=entity_tend, size=10)
        s1 = s[0]
        _temp1 = []
        for i in range(len(s1)):
            if s1[i] == "。":
                _temp1.append(i)
        if _temp1:
            s1 = s1[_temp1[-1] + 1:]
        s2 = s[1]
        _temp2 = []
        for i in range(len(s2)):
            if s2[i] == "。":
                _temp2.append(i)
                break
        if _temp2:
            s2 = s2[:_temp2[0] + 1]
        # print(s2)
        context_left.append(s1)
        context_right.append(s2)
        print(id)
        # print(_label)
        # print(entity_text)
        # print(tokens[entity_tbegin:entity_tend])
    data['token_begin'] = token_begin
    data['token_end'] = token_end
    data['context_left'] = context_left
    data['context_right'] = context_right
    data['label'] = label
    data = data.drop(['tokens', 'offsets_to_text', 'sentences'], axis=1)
    data.to_csv("C:\\Users\\admin\\Desktop\\tokens_data.csv")
def plot_loss(history):
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.show()
if __name__ == '__main__':
    # get_data()
    # getModel()
    # training()
    # train2()
    # data_process()
    # data_process2()
    # data_process3()
    # predict()
    # predict2()
    pass