# train.py
'''
Created on 2019-04-15
@author: User
'''
  5. import os
  6. import sys
  7. sys.path.append(os.path.abspath("../.."))
  8. import pandas as pd
  9. import gensim
  10. import numpy as np
  11. import math
  12. import models
  13. from keras.callbacks import ModelCheckpoint
  14. from BiddingKG.dl.common.Utils import *
  15. def embedding(datas,shape):
  16. '''
  17. @summary:查找词汇对应的词向量
  18. @param:
  19. datas:词汇的list
  20. shape:结果的shape
  21. @return: array,返回对应shape的词嵌入
  22. '''
  23. model_w2v = getModel_word()
  24. embed = np.zeros(shape)
  25. length = shape[1]
  26. out_index = 0
  27. #print(datas)
  28. for data in datas:
  29. index = 0
  30. for item in str(data)[-shape[1]:]:
  31. if index>=length:
  32. break
  33. if item in model_w2v.vocab:
  34. embed[out_index][index] = model_w2v[item]
  35. index += 1
  36. else:
  37. #embed[out_index][index] = model_w2v['unk']
  38. index += 1
  39. out_index += 1
  40. return embed
  41. def labeling(label,out_len=2):
  42. out = np.zeros((out_len))
  43. out[label] = 1
  44. return out
  45. def getTrainData(percent=0.9):
  46. train_x = []
  47. train_y = []
  48. test_x = []
  49. test_y = []
  50. files = ["批量.xls","剩余手工标注.xls"]
  51. for file in files:
  52. df = pd.read_excel(file)
  53. for before,text,after,label in zip(df["list_before"],df["list_code"],df["list_after"],df["list_label"]):
  54. the_label = 0
  55. if not math.isnan(label):
  56. the_label = int(label)
  57. if the_label not in [0,1]:
  58. print(after,text)
  59. continue
  60. x = embedding([before,text,after],shape=(3,40,60))
  61. y = labeling(the_label)
  62. if np.random.random()<percent:
  63. train_x.append(x)
  64. train_y.append(y)
  65. else:
  66. test_x.append(x)
  67. test_y.append(y)
  68. return np.transpose(np.array(train_x),(1,0,2,3)),np.array(train_y),np.transpose(np.array(test_x),(1,0,2,3)),np.array(test_y)
  69. def train():
  70. #train_x,train_y,test_x,test_y = getTrainData()
  71. #save((train_x,train_y,test_x,test_y),"data.pk")
  72. train_x,train_y,test_x,test_y = load("data.pk")
  73. model = models.getTextCNNModel()
  74. # model.load_weights("log/ep012-loss0.049-val_loss0.071-f1_score0.979.h5")
  75. callback = ModelCheckpoint(filepath="log/"+"ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1_score{val_f1_score:.3f}.h5",monitor="val_loss",save_best_only=True, save_weights_only=True, mode="min")
  76. model.fit(x=[train_x[0],train_x[1],train_x[2]],y=train_y,batch_size=24,epochs=400,callbacks=[callback],validation_data=[[test_x[0],test_x[1],test_x[2]],test_y])
def test():
    """Export the trained classifier as a full model file.

    Builds a fresh TextCNN model, loads the trained weights from
    models/model_code.h5, and saves the complete model (architecture +
    weights) to model_code.h5 in the current directory.
    """
    model = models.getTextCNNModel()
    model.load_weights("models/model_code.h5")
    model.save("model_code.h5")
# Script entry point: run training by default; switch to test() to
# export the saved weights as a full model file.
if __name__=="__main__":
    train()
    #test()