model.py

import os
import sys

sys.path.append(os.path.abspath("../.."))

import keras.backend as K
from keras import layers, models
from keras.layers import Lambda

from BiddingKG.dl.table_head.models.my_average_pooling import MyAveragePooling1D
from BiddingKG.dl.table_head.models.self_attention import SeqSelfAttention


def get_model(input_shape, output_shape):
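    """Build the table-head classification model.

    Six (batch, 10, 60) inputs are grouped into two three-box sequences,
    each run through a Bi-LSTM and self-attention; a third branch encodes
    the single center box (input_2). The three squeezed branch outputs are
    concatenated and reduced to one sigmoid score.
    """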
    # Input (batch, 10, 60)
    input_1 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_2 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_3 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_4 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_5 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_6 = layers.Input(shape=input_shape[1:], dtype="float32")

    # ----------- Three box sequence -----------
    # Concat (batch, 30, 60)
    concat_1 = layers.concatenate([input_1, input_2, input_3], axis=-2, name='seq_concat')
    concat_2 = layers.concatenate([input_4, input_5, input_6], axis=-2)

    # Bi-LSTM (batch, 30, 128)
    bi_lstm_1 = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(concat_1)
    bi_lstm_2 = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(concat_2)

    # Self-Attention (batch, 30, 128)
    self_attention_1 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_1)
    self_attention_2 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_2)

    # Dense (batch, 30, 1)
    dense_1 = layers.Dense(output_shape[0], activation="relu")(self_attention_1)
    dense_2 = layers.Dense(output_shape[0], activation="relu")(self_attention_2)

    # Squeeze (batch, 30)
    squeeze_1 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_1)
    squeeze_2 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_2)

    # ----------- One box feature -----------
    # Bi-LSTM (batch, 10, 128)
    bi_lstm = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(input_2)

    # Self-Attention (batch, 10, 128)
    self_attention = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm)

    # mask mean pooling
    # pool_1 = MyAveragePooling1D(axis=-1)(self_attention_1)

    # Dense (batch, 10, 1)
    dense = layers.Dense(output_shape[0], activation="relu")(self_attention)

    # Squeeze (batch, 10) - one box feature
    squeeze = Lambda(lambda x: K.squeeze(x, axis=-1))(dense)
    # ----------- Three box sequence & One box feature -----------
    # Concat (batch, 70)
    concat = layers.concatenate([squeeze, squeeze_1, squeeze_2])

    # Dense (batch, 64) -> Dense (batch, 1)
    output = layers.Dense(64, activation='relu')(concat)
    output = layers.Dense(1, activation="sigmoid", name='output')(output)

    model = models.Model(inputs=[input_1, input_2, input_3, input_4, input_5, input_6],
                         outputs=output)
    model.summary()
    return model
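

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: each input is (batch, 10, 60) as
    # the shape comments above indicate, and output_shape[0] == 1 so the
    # per-step Dense layers emit a single score per position). The
    # optimizer and loss here are illustrative, not taken from this repo's
    # training code; binary cross-entropy matches the sigmoid output.
    model = get_model(input_shape=(None, 10, 60), output_shape=(1,))
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])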