model.py

from keras import layers, models

from BiddingKG.dl.table_head.models.my_average_pooling import MyAveragePooling1D
from BiddingKG.dl.table_head.models.self_attention import SeqSelfAttention


def get_model(input_shape, output_shape):
    # Inputs: two parallel padded token-id sequences
    input_1 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_2 = layers.Input(shape=input_shape[1:], dtype="float32")

    # Embedding: vocabulary of 6624 ids mapped to 32-dim vectors;
    # mask_zero=True propagates the padding mask to downstream layers
    embed_1 = layers.Embedding(input_dim=6624, output_dim=32,
                               input_length=input_shape[1], mask_zero=True)(input_1)
    embed_2 = layers.Embedding(input_dim=6624, output_dim=32,
                               input_length=input_shape[1], mask_zero=True)(input_2)

    # Bi-LSTM: 16 units per direction, full sequence output (batch, steps, 32)
    bi_lstm_1 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(embed_1)
    bi_lstm_2 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(embed_2)

    # Self-attention over each sequence, shape-preserving
    self_attention_1 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_1)
    self_attention_2 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_2)

    # Concatenate the two attended sequences along the feature axis
    concat = layers.concatenate([self_attention_1, self_attention_2])

    # Dense + softmax, applied per time step
    output = layers.Dense(output_shape[0], activation="softmax")(concat)

    # Mask-aware mean pooling over the time axis collapses (batch, steps,
    # classes) to (batch, classes)
    output = MyAveragePooling1D(axis=1)(output)

    model = models.Model(inputs=[input_1, input_2], outputs=output)
    model.summary()
    return model
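
Usage sketch (not part of model.py): one way this builder might be driven. The sequence length (100), class count (3), batch size, optimizer, and loss below are illustrative assumptions, not values taken from the BiddingKG project; only get_model's signature comes from the code above.

import numpy as np

# Hypothetical shapes: 100-token padded sequences, 3 output classes.
model = get_model(input_shape=(None, 100), output_shape=(3,))
model.compile(optimizer="adam",
              loss="categorical_crossentropy",  # assumed loss for the sketch
              metrics=["accuracy"])

# Two token-id sequences per sample; id 0 is the padding value that
# mask_zero=True in the Embedding layers masks out.
x1 = np.random.randint(1, 6624, size=(8, 100)).astype("float32")
x2 = np.random.randint(1, 6624, size=(8, 100)).astype("float32")
y = np.eye(3)[np.random.randint(0, 3, size=8)]  # one-hot labels, shape (8, 3)

model.fit([x1, x2], y, epochs=1, batch_size=8)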