main.py

# encoding=utf-8
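"""BiGRU + additive attention text classifier.

bigru_attention_softmax() builds the model, train() fits it on pickled,
pre-tokenised data and checkpoints the best weights by validation F1,
and test() reloads those weights to classify a single document.
"""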
import pickle
from data_precess import get_train_test_data
from data_util import (precision, recall, f1_score, get_remove_word, get_embedding,
                       get_label, get_dic, clean_word_with_tokenizer, model_in)
import keras.backend as K
from keras.layers import Input, Embedding, Bidirectional, GRU, Dropout, Dense
from keras.models import Model
from keras import models
from keras.callbacks import ModelCheckpoint
from keras.engine.topology import Layer
class Attention(Layer):
    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # W: (EMBED_SIZE, 1)
        # b: (MAX_TIMESTEPS, 1)
        # u: (MAX_TIMESTEPS, MAX_TIMESTEPS)
        self.W = self.add_weight(name="W_{:s}".format(self.name),
                                 shape=(input_shape[-1], 1),
                                 initializer="normal")
        self.b = self.add_weight(name="b_{:s}".format(self.name),
                                 shape=(input_shape[1], 1),
                                 initializer="zeros")
        self.u = self.add_weight(name="u_{:s}".format(self.name),
                                 shape=(input_shape[1], input_shape[1]),
                                 initializer="normal")
        super(Attention, self).build(input_shape)

    def call(self, x, mask=None):
        # input x: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        # et: (BATCH_SIZE, MAX_TIMESTEPS)
        et = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=-1)
        # at: (BATCH_SIZE, MAX_TIMESTEPS)
        at = K.dot(et, self.u)
        at = K.exp(at)
        if mask is not None:
            at *= K.cast(mask, K.floatx())
        at /= K.cast(K.sum(at, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        atx = K.expand_dims(at, axis=-1)
        ot = atx * x
        # output: (BATCH_SIZE, EMBED_SIZE)
        return K.sum(ot, axis=1)

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def compute_output_shape(self, input_shape):
        # output shape: (BATCH_SIZE, EMBED_SIZE)
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        return super(Attention, self).get_config()
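# Attention computation, matching call() above:
#   et = tanh(x . W + b)                    per-timestep score         (BATCH, MAX_TIMESTEPS)
#   at = softmax(et . u)                    attention weights; padded timesteps are
#                                           zeroed via the mask before normalisation
#   output = sum_t at[:, t] * x[:, t, :]    weighted sum of timesteps  (BATCH, EMBED_SIZE)
# compute_mask() returns None, so the mask created by Embedding(mask_zero=True)
# is consumed here and not propagated to the Dense layers.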
def bigru_attention_softmax(input_size, word_index, embedding_matrix, classes):
    sent_inputs = Input(shape=(input_size,), dtype="float64")
    sent_emb = Embedding(input_dim=len(word_index) + 1,
                         output_dim=100,
                         mask_zero=True,
                         weights=[embedding_matrix])(sent_inputs)
    sent_enc = Bidirectional(GRU(128, dropout=0.2, recurrent_dropout=0.2,
                                 return_sequences=True))(sent_emb)
    embeddings = Dropout(0.2)(sent_enc)
    sent_att1 = Attention()(embeddings)
    fc1_dropout = Dropout(0.2)(sent_att1)
    fc1 = Dense(422, activation="relu")(fc1_dropout)
    fc2_dropout = Dropout(0.2)(fc1)
    sent_pred = Dense(classes, activation="softmax")(fc2_dropout)
    model = Model(inputs=sent_inputs, outputs=sent_pred)
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=[precision, recall, f1_score])
    model.summary()
    return model
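# Architecture built above:
#   Input(input_size) -> Embedding(100-dim, mask_zero=True) -> Bidirectional GRU(128,
#   return_sequences=True) -> Dropout(0.2) -> Attention -> Dropout(0.2) ->
#   Dense(422, relu) -> Dropout(0.2) -> Dense(classes, softmax)
# Note: output_dim=100 assumes every row of embedding_matrix is a 100-dimensional
# vector; pretrained embeddings of any other width will fail when the weights are set.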
def bigru_attention_softmax_weights(input_size, word_index, embedding_matrix, labels, weight):
    model_gru_attention = bigru_attention_softmax(input_size, word_index, embedding_matrix, labels)
    model_gru_attention.load_weights(weight)
    return model_gru_attention
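# load_weights() only restores parameters, so the architecture arguments passed here
# must match the ones used at training time (input_size=100 and 211 classes in train()).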
def train():
    with open('padded_sequences.pkl', 'rb') as f:
        padded_sequences = pickle.load(f)        # tokenised and padded training inputs (x)
    with open('labels_np.pkl', 'rb') as f:
        labels_np = pickle.load(f)               # one-hot encoded training labels (y)
    with open('padded_sequences_te.pkl', 'rb') as f:
        padded_sequences_te = pickle.load(f)     # tokenised and padded test inputs (x)
    with open('test_label.pkl', 'rb') as f:
        test_label = pickle.load(f)              # one-hot encoded test labels (y)
    with open('word_index.pkl', 'rb') as f:
        word_index = pickle.load(f)              # vocabulary dict, word -> id
    with open('embedding_matrix.pkl', 'rb') as f:
        embedding_matrix = pickle.load(f)        # pretrained word-embedding matrix
    # padded_sequences, labels_np, padded_sequences_te, test_label, word_index, embedding_matrix = get_train_test_data()
    checkpoint_gru_attention = ModelCheckpoint('model/New_attentionLSTM_weights1_100_em21.h5', monitor="val_f1_score",
                                               verbose=1, save_best_only=True, mode='max')
    model_gru_attention = bigru_attention_softmax(100, word_index, embedding_matrix, 211)
    model_gru_attention.fit(padded_sequences, labels_np,
                            callbacks=[checkpoint_gru_attention], shuffle=True,
                            validation_data=(padded_sequences_te, test_label),
                            epochs=25, batch_size=1024)  # batch_size 128
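# The pickled arrays loaded above are assumed to be the output of
# get_train_test_data() from data_precess (see the commented-out call),
# presumably saved ahead of time so training does not repeat the preprocessing step.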
def test():
    remove_word = get_remove_word()    # load stop words and other unimportant words
    word_index, tokenizer, embedding_matrix = get_embedding()    # load vocabulary dict, keras Tokenizer object and embedding matrix
    label_mapping, labels = get_label()    # load the label dict label_mapping, e.g. {0: '安防系统', 1: '安全保护服务', 2: '安全保护设备', ...}; labels holds the Chinese names of all classes
    labels = 211    # number of classes
    gru_weights = 'model/New_attentionLSTM_weights1_100_em21.h5'
    model_gru_attention = bigru_attention_softmax_weights(100, word_index, embedding_matrix, labels, gru_weights)
    file = 'F:\\工作文档\\近义词\\text_zhaobiao\\比地_1 畜禽养殖工程_2.txt'
    with open(file, 'r', encoding='utf-8') as f:
        k = file
        content = f.read()
    q = {"id": k, "content": content}
    ContentIDs = []
    ContentIDs.append([q['id'], q['content']])
    x_train, id = clean_word_with_tokenizer(ContentIDs, remove_word, tokenizer)
    gru_te = model_gru_attention.predict(x_train)
    union = model_in(gru_te, label_mapping, id)
    print(union)
if __name__ == '__main__':
    train()
    # test()