@@ -73,6 +73,130 @@ def model_1(input_shape, output_shape):
     return model
 
 
+def model_1_small(input_shape, output_shape):
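+    # smaller variant of model_1: same branch topology, 32-unit Bi-LSTMs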
+    # Input (batch, 10, 60)
+    input_1 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_2 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_3 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_4 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_5 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_6 = layers.Input(shape=input_shape[1:], dtype="float32")
+
+    # ----------- Three box sequence -----------
+    # Concat (batch, 30, 60)
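+    # axis=-2 stacks the three 10-step boxes along the time axis (3 x 10 = 30)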
+    concat_1 = layers.concatenate([input_1, input_2, input_3], axis=-2, name='seq_concat')
+    concat_2 = layers.concatenate([input_4, input_5, input_6], axis=-2)
+
+    # Bi-LSTM (batch, 30, 64)
+    bi_lstm_1 = layers.Bidirectional(layers.LSTM(32, return_sequences=True))(concat_1)
+    bi_lstm_2 = layers.Bidirectional(layers.LSTM(32, return_sequences=True))(concat_2)
+
+    # Self-Attention (batch, 30, 64)
+    self_attention_1 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_1)
+    self_attention_2 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_2)
+
+    # Dense (batch, 30, 1)
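+    # output_shape[0] is expected to be 1 here, so the squeeze below gives (batch, 30)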
+    dense_1 = layers.Dense(output_shape[0], activation="relu")(self_attention_1)
+    dense_2 = layers.Dense(output_shape[0], activation="relu")(self_attention_2)
+
+    # Squeeze (batch, 30)
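+    # Lambda and K (keras.backend) are assumed to be imported at the top of this file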
+    squeeze_1 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_1)
+    squeeze_2 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_2)
+
+    # ----------- One box feature -----------
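+    # only input_2 feeds this branch (presumably the box being scored)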
+    # Bi-LSTM (batch, 10, 64)
+    bi_lstm = layers.Bidirectional(layers.LSTM(32, return_sequences=True))(input_2)
+
+    # Self-Attention (batch, 10, 64)
+    self_attention = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm)
+
+    # mask mean pooling
+    # pool_1 = MyAveragePooling1D(axis=-1)(self_attention_1)
+
+    # Dense (batch, 10, 1)
+    dense = layers.Dense(output_shape[0], activation="relu")(self_attention)
+
+    # Squeeze (batch, 10) - one box feature
+    squeeze = Lambda(lambda x: K.squeeze(x, axis=-1))(dense)
+
+    # ----------- Three box sequence & One box feature -----------
+    # Concat (batch, 70), then Dense down to (batch, 1)
+    concat = layers.concatenate([squeeze, squeeze_1, squeeze_2])
+    output = layers.Dense(32, activation='relu')(concat)
+    output = layers.Dense(1, activation="sigmoid", name='output')(output)
+
+    model = models.Model(inputs=[input_1, input_2, input_3, input_4, input_5, input_6],
+                         outputs=output)
+
+    # model.summary()
+    return model
+
+
+def model_1_tiny(input_shape, output_shape):
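+    # tiny variant: same topology as model_1_small, 16-unit Bi-LSTMs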
+    # Input (batch, 10, 60)
+    input_1 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_2 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_3 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_4 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_5 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_6 = layers.Input(shape=input_shape[1:], dtype="float32")
+
+    # ----------- Three box sequence -----------
+    # Concat (batch, 30, 60)
+    concat_1 = layers.concatenate([input_1, input_2, input_3], axis=-2, name='seq_concat')
+    concat_2 = layers.concatenate([input_4, input_5, input_6], axis=-2)
+
+    # Bi-LSTM (batch, 30, 32)
+    bi_lstm_1 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(concat_1)
+    bi_lstm_2 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(concat_2)
+
+    # Self-Attention (batch, 30, 32)
+    self_attention_1 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_1)
+    self_attention_2 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_2)
+
+    # Dense (batch, 30, 1)
+    dense_1 = layers.Dense(output_shape[0], activation="relu")(self_attention_1)
+    dense_2 = layers.Dense(output_shape[0], activation="relu")(self_attention_2)
+
+    # Squeeze (batch, 30)
+    squeeze_1 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_1)
+    squeeze_2 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_2)
+
+    # ----------- One box feature -----------
+    # Bi-LSTM (batch, 10, 32)
+    bi_lstm = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_2)
+
+    # Self-Attention (batch, 10, 32)
+    self_attention = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm)
+
+    # mask mean pooling
+    # pool_1 = MyAveragePooling1D(axis=-1)(self_attention_1)
+
+    # Dense (batch, 10, 1)
+    dense = layers.Dense(output_shape[0], activation="relu")(self_attention)
+
+    # Squeeze (batch, 10) - one box feature
+    squeeze = Lambda(lambda x: K.squeeze(x, axis=-1))(dense)
+
+    # ----------- Three box sequence & One box feature -----------
+    # Concat (batch, 70), then Dense down to (batch, 1)
+    concat = layers.concatenate([squeeze, squeeze_1, squeeze_2])
+    output = layers.Dense(16, activation='relu')(concat)
+    output = layers.Dense(1, activation="sigmoid", name='output')(output)
+
+    model = models.Model(inputs=[input_1, input_2, input_3, input_4, input_5, input_6],
+                         outputs=output)
+
+    # model.summary()
+    return model
+
+
 def model_2(input_shape, output_shape):
     # input_shape = (None, None, 10, 60)
     # (batch_size, row_num, col_num, character_num, character_embedding)
@@ -266,7 +390,7 @@ def model_3(input_shape, output_shape):
 
 def get_model(input_shape, output_shape, model_id):
     if model_id == 1:
-        return model_1(input_shape, output_shape)
+        return model_1_tiny(input_shape, output_shape)
     elif model_id == 2:
         return model_2(input_shape, output_shape)
     elif model_id == 3: