model_2.py

import sys
import os

# Make the project root importable before loading the BiddingKG modules
sys.path.append(os.path.abspath("../.."))

from keras import layers, models
import tensorflow as tf
from BiddingKG.dl.table_head.models.my_average_pooling import MyAveragePooling1D
from BiddingKG.dl.table_head.models.self_attention import SeqSelfAttention
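
# Both imported layers are custom to this repository:
# - SeqSelfAttention: self-attention applied across timesteps; the class name
#   and the attention_activation argument match the keras-self-attention
#   package, so it is presumably a vendored copy of that implementation.
# - MyAveragePooling1D: mean pooling along a chosen axis; judging by its use
#   below, it performs masked mean pooling over the timestep axis.
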
def get_model(input_shape, output_shape):
    """Build the table_head classifier: six parallel feature sequences are
    encoded by per-input Bi-LSTMs with self-attention, merged in two groups
    of three, pooled over time, and reduced to a single sigmoid score."""
    # Inputs: six feature sequences, all with the same (timesteps, features) shape
    input_1 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_2 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_3 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_4 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_5 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_6 = layers.Input(shape=input_shape[1:], dtype="float32")

    # Bi-LSTM: a separate encoder (16 units per direction) for each input
    bi_lstm_1 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_1)
    bi_lstm_2 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_2)
    bi_lstm_3 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_3)
    bi_lstm_4 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_4)
    bi_lstm_5 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_5)
    bi_lstm_6 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_6)

    # Self-attention over each Bi-LSTM output sequence
    self_attention_1 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_1)
    self_attention_2 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_2)
    self_attention_3 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_3)
    self_attention_4 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_4)
    self_attention_5 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_5)
    self_attention_6 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_6)

    # Concatenate the attended sequences into two groups of three (feature axis)
    concat_1 = layers.concatenate([self_attention_1, self_attention_2, self_attention_3])
    concat_2 = layers.concatenate([self_attention_4, self_attention_5, self_attention_6])

    # Per-timestep dense projection with sigmoid activation
    dense_1 = layers.Dense(output_shape[0], activation="sigmoid")(concat_1)
    dense_2 = layers.Dense(output_shape[0], activation="sigmoid")(concat_2)

    # Masked mean pooling over the timestep axis
    pool_1 = MyAveragePooling1D(axis=1)(dense_1)
    pool_2 = MyAveragePooling1D(axis=1)(dense_2)

    # Merge the two pooled group representations
    concat = layers.concatenate([pool_1, pool_2])

    # Classifier head: a 10-unit (linear) projection, then one sigmoid unit
    output = layers.Dense(10)(concat)
    output = layers.Dense(1, activation="sigmoid", name='output')(output)

    model = models.Model(inputs=[input_1, input_2, input_3, input_4, input_5, input_6],
                         outputs=output)
    model.summary()
    return model
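
# Example usage: a minimal smoke-test sketch. The concrete shapes here are
# assumptions; the real ones come from the BiddingKG training pipeline.
if __name__ == "__main__":
    import numpy as np

    input_shape = (4, 10, 60)   # (batch, timesteps, features) -- assumed
    output_shape = (1,)         # one binary score -- assumed

    model = get_model(input_shape, output_shape)
    # Assumed training setup: binary cross-entropy matches the sigmoid output
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

    dummy_inputs = [np.random.random((4, 10, 60)).astype("float32") for _ in range(6)]
    # Expected (4, 1), assuming MyAveragePooling1D reduces the timestep axis
    print(model.predict(dummy_inputs).shape)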