import os
import sys

# Make the repo root importable before the BiddingKG imports below.
sys.path.append(os.path.abspath("../.."))

import keras.backend as K
from keras import layers, models
from keras.layers import Lambda

# MyAveragePooling1D is only referenced by the masked-pooling variant
# that is commented out inside get_model().
from BiddingKG.dl.table_head.models.my_average_pooling import MyAveragePooling1D
from BiddingKG.dl.table_head.models.self_attention import SeqSelfAttention


def get_model(input_shape, output_shape):
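    """Build the table-head classifier.

    Three parallel branches score a candidate cell: two three-box
    sequences (inputs 1-3 and 4-6, concatenated along the time axis) and
    a one-box branch on input_2 alone, each passed through a Bi-LSTM and
    self-attention before the branches are merged into a sigmoid score.

    Args:
        input_shape: full shape including batch, e.g. (None, 10, 60);
            only input_shape[1:] is passed to the Input layers.
        output_shape: output_shape[0] is the per-step Dense width and is
            expected to be 1 so the squeezes below are valid.

    Returns:
        An uncompiled keras Model mapping the six inputs to (batch, 1).
    """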
    # Input (batch, 10, 60)
    input_1 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_2 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_3 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_4 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_5 = layers.Input(shape=input_shape[1:], dtype="float32")
    input_6 = layers.Input(shape=input_shape[1:], dtype="float32")
    # ----------- Three box sequence -----------
    # Concat (batch, 30, 60)
    concat_1 = layers.concatenate([input_1, input_2, input_3], axis=-2, name='seq_concat')
    concat_2 = layers.concatenate([input_4, input_5, input_6], axis=-2)

    # Bi-LSTM (batch, 30, 128)
    bi_lstm_1 = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(concat_1)
    bi_lstm_2 = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(concat_2)

    # Self-Attention (batch, 30, 128)
    self_attention_1 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_1)
    self_attention_2 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_2)

    # Dense (batch, 30, 1)
    dense_1 = layers.Dense(output_shape[0], activation="relu")(self_attention_1)
    dense_2 = layers.Dense(output_shape[0], activation="relu")(self_attention_2)

    # Squeeze (batch, 30)
    squeeze_1 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_1)
    squeeze_2 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_2)
    # ----------- One box feature -----------
    # Bi-LSTM (batch, 10, 128)
    bi_lstm = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(input_2)

    # Self-Attention (batch, 10, 128)
    self_attention = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm)

    # Mask mean pooling (commented-out alternative)
    # pool_1 = MyAveragePooling1D(axis=-1)(self_attention_1)

    # Dense (batch, 10, 1)
    dense = layers.Dense(output_shape[0], activation="relu")(self_attention)

    # Squeeze (batch, 10) - one box feature
    squeeze = Lambda(lambda x: K.squeeze(x, axis=-1))(dense)
    # ----------- Three box sequence & One box feature -----------
    # Concat (batch, 70): 10 + 30 + 30 squeezed scores
    concat = layers.concatenate([squeeze, squeeze_1, squeeze_2])

    # Dense (batch, 1)
    output = layers.Dense(64, activation='relu')(concat)
    output = layers.Dense(1, activation="sigmoid", name='output')(output)

    model = models.Model(inputs=[input_1, input_2, input_3, input_4, input_5, input_6],
                         outputs=output)
    model.summary()
    return model
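

# Minimal usage sketch, not part of the original module. Shapes follow the
# comments above (six (10, 60) inputs, one sigmoid output); the Adam /
# binary-crossentropy setup is an assumption, as the module does not
# specify its training configuration.
if __name__ == "__main__":
    import numpy as np

    model = get_model(input_shape=(None, 10, 60), output_shape=(1,))
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])

    # Score a random batch of 4 samples: six parallel (4, 10, 60) arrays
    # in, one (4, 1) array of probabilities out.
    xs = [np.random.rand(4, 10, 60).astype("float32") for _ in range(6)]
    print(model.predict(xs).shape)  # -> (4, 1)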