import keras
import tensorflow as tf
from keras import models, backend as K
from keras.layers import Layer, Input, Lambda, Concatenate, Dense, LSTM, Bidirectional
# Only used by the commented-out bidirectional_dynamic_rnn experiment below (TensorFlow 1.x)
from tensorflow.contrib.rnn import LSTMCell
import numpy as np
from BiddingKG.dl.table_head.models.self_attention import SeqSelfAttention
from BiddingKG.dl.table_head.models.u_net import u_net_small


def attention(inputs, w_omega, b_omega, u_omega, time_major=False):
    """Additive attention pooling over the time axis: (B, T, D) -> (B, D)."""
    # Bi-RNN outputs may arrive as a (forward, backward) tuple; concatenate them on the feature axis
    if isinstance(inputs, tuple):
        inputs = tf.concat(inputs, 2)

    if time_major:
        # (T, B, D) => (B, T, D)
        inputs = tf.transpose(inputs, [1, 0, 2])

    # (B, T, A): score each timestep against the context vector u_omega
    v = tf.tanh(tf.tensordot(inputs, w_omega, axes=1) + b_omega)

    vu = tf.tensordot(v, u_omega, axes=1, name='vu')  # (B, T) raw scores
    alphas = tf.nn.softmax(vu, name='alphas')         # (B, T) attention weights

    # Weighted sum over time; the result has (B, D) shape
    output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), 1)

    return output, alphas
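
# Illustrative sketch only: shows the tensor shapes `attention` produces for a
# (B, T, D) input with randomly initialized w_omega / b_omega / u_omega,
# mirroring the weights that are commented out in LoopCell.build below.
# The helper name `_attention_shape_demo` is an assumption and is never called
# by the model; invoke it manually to inspect the shapes.
def _attention_shape_demo(batch=4, time_steps=10, depth=128, attention_size=64):
    inputs = tf.constant(np.random.uniform(0, 1, (batch, time_steps, depth)).astype(np.float32))
    w_omega = tf.constant(np.random.uniform(-0.25, 0.25, (depth, attention_size)).astype(np.float32))
    b_omega = tf.constant(np.random.uniform(-0.25, 0.25, (attention_size,)).astype(np.float32))
    u_omega = tf.constant(np.random.uniform(-0.25, 0.25, (attention_size,)).astype(np.float32))

    output, alphas = attention(inputs, w_omega, b_omega, u_omega)
    print(output.shape)  # (4, 128) -> (B, D)
    print(alphas.shape)  # (4, 10)  -> (B, T)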

class LoopCell(Layer):
    """Per-cell encoder: flattens the table grid so every cell's character
    sequence becomes one sample of shape (character_num, character_embed).
    The Bi-LSTM / self-attention / attention-pooling pipeline sketched in the
    comments is currently disabled."""

    def __init__(self, hidden_size, attention_size, character_num, character_embed, cell_embed):
        super(LoopCell, self).__init__()

        # Hyper parameters
        self.hidden_size = hidden_size
        self.attention_size = attention_size
        self.character_num = character_num
        self.character_embed = character_embed
        self.cell_embed = cell_embed

    def build(self, batch_input_shape):
        super(LoopCell, self).build(batch_input_shape)

        # Trainable parameters

        # Attention
        # self.w_omega = self.add_weight("w_omega", shape=[self.hidden_size*2, self.attention_size],
        #                                initializer=tf.random_uniform_initializer(-0.25, 0.25),
        #                                trainable=True)
        # self.b_omega = self.add_weight("b_omega", shape=[self.attention_size],
        #                                initializer=tf.random_uniform_initializer(-0.25, 0.25),
        #                                trainable=True)
        # self.u_omega = self.add_weight("u_omega", shape=[self.attention_size],
        #                                initializer=tf.random_uniform_initializer(-0.25, 0.25),
        #                                trainable=True)

        # Bi-LSTM
        # self.forward_cell = LSTMCell(self.hidden_size, forget_bias=1.0, state_is_tuple=True)
        # self.backward_cell = LSTMCell(self.hidden_size, forget_bias=1.0, state_is_tuple=True)
        # self.bi_lstm = Bidirectional(LSTM(self.hidden_size, return_sequences=True))
        # self.bi_lstm.build(input_shape=(None, self.character_num, self.character_embed))
        # self.trainable_weights += self.bi_lstm.trainable_weights
        #
        # self.self_attention = SeqSelfAttention(attention_activation='sigmoid')
        # self.self_attention.build(input_shape=(None, self.character_num, 2*self.hidden_size))
        # self.trainable_weights += self.self_attention.trainable_weights
        # print(self.trainable_weights)

        # DNN
        # self.w1 = self.add_weight('W1', [2*self.attention_size, self.cell_embed],
        #                           initializer=tf.random_uniform_initializer(-0.25, 0.25),
        #                           trainable=True)
        #
        # self.b1 = self.add_weight('b1', [self.cell_embed],
        #                           initializer=tf.zeros_initializer(),
        #                           trainable=True)

        # self.dense = Dense(self.cell_embed, activation="relu")
        # print(batch_input_shape[0], batch_input_shape[1], batch_input_shape[2])
        # self.dense.build(input_shape=(batch_input_shape[0]*batch_input_shape[1]*batch_input_shape[2],
        #                               2*self.attention_size))
        # self.trainable_weights += self.dense.trainable_weights

    def call(self, inputs, mask=None, **kwargs):
        def fn(x):
            # NOTE: this helper is currently unused (every call site below is
            # commented out) and its body is fully commented out, so `outputs`
            # would be undefined if it were ever invoked.
            print("fn_0", x)

            # (batch*height*width, character_num, hidden_size)
            # outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=self.forward_cell,
            #                                                        cell_bw=self.backward_cell,
            #                                                        inputs=x,
            #                                                        dtype=tf.float32,
            #                                                        time_major=False)

            # (batch*height*width, character_num, 2*hidden_size)
            # outputs = self.bi_lstm(x)
            # print("fn_1", outputs)

            # (batch*height*width, character_num, 2*hidden_size)
            # outputs = self.self_attention(outputs)
            # print("fn_2", outputs)

            # (batch*height*width, 2*hidden_size)
            # outputs, _ = attention(outputs, self.w_omega, self.b_omega,
            #                        self.u_omega, time_major=False)

            # (batch*height*width, cell_embedding)
            # outputs = tf.nn.xw_plus_b(outputs, self.w1, self.b1)
            # outputs = self.dense(outputs)
            # print("fn_3", outputs)
            return outputs

        batch = tf.shape(inputs)[0]
        height = tf.shape(inputs)[1]
        width = tf.shape(inputs)[2]

        # (batch, height*width, character_num(time_step), character_embedding)
        # inputs = tf.reshape(inputs, (tf.shape(inputs)[0],
        #                              height*width,
        #                              inputs.shape[3], inputs.shape[4]))

        # (batch*height*width, character_num, character_embedding)
        outputs = tf.reshape(inputs, (batch*height*width, inputs.shape[3], inputs.shape[4]))

        # (height*width, batch, character_num(time_step), character_embedding)
        # inputs = tf.transpose(inputs, (1, 0, 2, 3))

        # split height*width, each cell
        # (height*width, batch, cell_embedding)
        # outputs = tf.map_fn(fn=lambda x: fn(x), elems=inputs, dtype=tf.float32)
        # print("loop_lstm_1", outputs)
        # outputs = tf.squeeze(outputs, 0)

        # (batch*height*width, 2*attention_size)
        # outputs = fn(inputs)
        # print("loop_lstm_2", outputs)

        # (1, batch*height*width, 2*attention_size)
        # outputs = tf.expand_dims(outputs, 0)
        # print("loop_lstm_3", outputs)

        # (batch*height*width, cell_embedding)
        # outputs = Dense(self.cell_embed, activation="relu")(outputs)
        # print("loop_lstm_3", outputs)

        # (batch, height*width, cell_embedding)
        # outputs = tf.transpose(outputs, (1, 0, 2))
        # print("loop_lstm_2", outputs)

        # (batch, height, width, cell_embedding)
        # outputs = tf.reshape(outputs, (batch, height, width, self.cell_embed))
        # print("loop_lstm_4", outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        # One sample per table cell: (batch*height*width, character_num, character_embed)
        return None, self.character_num, self.character_embed


class BatchReshape(Layer):
    """Reshapes per-cell embeddings back to the table grid
    (batch, height, width, cell_embed), reading the dynamic batch/height/width
    from the first input."""

    def __init__(self, cell_embed):
        super(BatchReshape, self).__init__()
        self.cell_embed = cell_embed

    def call(self, inputs, mask=None, **kwargs):
        input1 = inputs[0]
        input2 = inputs[1]
        batch = tf.shape(input1)[0]
        height = tf.shape(input1)[1]
        width = tf.shape(input1)[2]

        # (batch, height, width, cell_embedding)
        outputs = tf.reshape(input2, (batch, height, width, self.cell_embed))
        print("batch_reshape", outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        return None, None, None, self.cell_embed


# def batch_reshape(x):
#     return K.reshape(x, (batch, height, width, cell_embed))
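
# Illustrative sketch only, not wired into the model: the per-cell encoder that
# the commented-out code in LoopCell.build / fn describes, expressed with plain
# Keras layers (Bi-LSTM over the cell's characters, SeqSelfAttention, then a
# projection down to cell_embed). The function name `_cell_encoder_sketch` is an
# assumption, and GlobalAveragePooling1D stands in for the additive `attention`
# pooling used in the original comments.
def _cell_encoder_sketch(character_num=10, character_embed=60, hidden_size=64, cell_embed=8):
    from keras.layers import GlobalAveragePooling1D

    # (batch*height*width, character_num, character_embed): one sample per table cell
    cell_input = Input(shape=(character_num, character_embed), dtype="float32")
    # (.., character_num, 2*hidden_size)
    x = Bidirectional(LSTM(hidden_size, return_sequences=True))(cell_input)
    # (.., character_num, 2*hidden_size)
    x = SeqSelfAttention(attention_activation='sigmoid')(x)
    # (.., 2*hidden_size): collapse the character axis
    x = GlobalAveragePooling1D()(x)
    # (.., cell_embed)
    cell_vector = Dense(cell_embed, activation="relu")(x)
    return models.Model(inputs=cell_input, outputs=cell_vector)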

if __name__ == '__main__':
    # (row_num, col_num, character_num, character_embedding); row/col stay dynamic
    # so the same model can be fit on differently sized tables below
    input_shape = (None, None, 10, 60)
    hidden_size = 64
    attention_size = 64
    character_num = 10
    character_embed = 60
    cell_embed = 8

    # (batch_size, row_num, col_num, character_num, character_embedding)
    X_train = np.random.uniform(0, 1, (10, 16, 8, 10, 60))
    X_test = np.random.uniform(0, 1, (10, 16, 8, 10, 60))
    y_train = np.random.uniform(0, 1, (10, 16, 8))
    y_test = np.random.uniform(0, 1, (10, 16, 8))

    _input = Input(shape=input_shape, dtype="float32")
    batch = K.shape(_input)[0]
    height = K.shape(_input)[1]
    width = K.shape(_input)[2]
    print(batch, height, width)

    loop_bi_lstm = LoopCell(hidden_size, attention_size,
                            character_num, character_embed,
                            cell_embed)(_input)
    print("model_2_1", loop_bi_lstm)

    dense = Dense(cell_embed, activation="relu")(loop_bi_lstm)
    print("model_2_2", dense)

    # (batch, height, width, cell_embed): recover the table grid from the
    # flattened per-cell embeddings; this branch is experimental and not used
    # by the model output
    reshape = BatchReshape(cell_embed)([_input, dense])
    print("model_2_3", reshape)

    u_net = u_net_small(loop_bi_lstm)
    merge = Concatenate(axis=-1)([loop_bi_lstm, u_net])

    dense = Dense(cell_embed, activation='relu')(merge)
    dense = Dense(1, activation='sigmoid')(dense)
    squeeze = Lambda(lambda x: K.squeeze(x, axis=-1))(dense)

    model = models.Model(inputs=_input, outputs=squeeze)
    model.summary(line_length=120)

    model.compile(loss='binary_crossentropy', optimizer='adam')
    model.fit(X_train, y_train,
              epochs=2,
              batch_size=1,
              validation_data=(X_test, y_test))

    # Second fit on a larger table grid; only character_num and character_embed are fixed
    # (batch_size, row_num, col_num, character_num, character_embedding)
    X_train = np.random.uniform(0, 1, (5, 32, 24, 10, 60))
    X_test = np.random.uniform(0, 1, (5, 32, 24, 10, 60))
    y_train = np.random.uniform(0, 1, (5, 32, 24))
    y_test = np.random.uniform(0, 1, (5, 32, 24))
    model.fit(X_train, y_train,
              epochs=2,
              batch_size=1,
              validation_data=(X_test, y_test))