# loop_lstm.py
import keras
import tensorflow as tf
from keras import models, backend as K
from keras.layers import Layer, Input, Lambda, Concatenate, Dense, LSTM, Bidirectional, Flatten
from tensorflow.contrib.rnn import LSTMCell  # TF 1.x only; referenced by the commented-out Bi-LSTM variant in build()
import numpy as np
from BiddingKG.dl.table_head.models.self_attention import SeqSelfAttention
from BiddingKG.dl.table_head.models.u_net import u_net_small


def attention(inputs, w_omega, b_omega, u_omega, time_major=False):
    if isinstance(inputs, tuple):
        inputs = tf.concat(inputs, 2)

    if time_major:  # (B,T,D) => (T,B,D)
        inputs = tf.transpose(inputs, [1, 0, 2])

    v = tf.tanh(tf.tensordot(inputs, w_omega, axes=1) + b_omega)

    vu = tf.tensordot(v, u_omega, axes=1, name='vu')  # (B,T) shape
    alphas = tf.nn.softmax(vu, name='alphas')         # (B,T) shape

    # the result has (B,D) shape
    output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), 1)
    return output, alphas
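

# Hedged sanity check for attention() on toy tensors. The shapes and the
# TF 1.x Session usage here are assumptions for illustration only; none of
# these tensors come from the model.
def _attention_shape_demo():
    B, T, D, A = 2, 5, 8, 4  # batch, time steps, feature dim, attention size
    x = tf.constant(np.random.rand(B, T, D), dtype=tf.float32)
    w_omega = tf.constant(np.random.rand(D, A), dtype=tf.float32)
    b_omega = tf.constant(np.random.rand(A), dtype=tf.float32)
    u_omega = tf.constant(np.random.rand(A), dtype=tf.float32)
    output, alphas = attention(x, w_omega, b_omega, u_omega)
    with tf.Session() as sess:
        out_v, alphas_v = sess.run([output, alphas])
        print(out_v.shape, alphas_v.shape)  # expected: (2, 8) (2, 5)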


class LoopCell(Layer):
    def __init__(self, hidden_size, attention_size, character_num, character_embed,
                 cell_embed):
        super(LoopCell, self).__init__()

        # Hyper parameters
        self.hidden_size = hidden_size
        self.attention_size = attention_size
        self.character_num = character_num
        self.character_embed = character_embed
        self.cell_embed = cell_embed

    def build(self, batch_input_shape):
        super(LoopCell, self).build(batch_input_shape)

        # Trainable parameters

        # Attention
        # self.w_omega = self.add_weight("w_omega", shape=[self.hidden_size*2, self.attention_size],
        #                                initializer=tf.random_uniform_initializer(-0.25, 0.25),
        #                                trainable=True)
        # self.b_omega = self.add_weight("b_omega", shape=[self.attention_size],
        #                                initializer=tf.random_uniform_initializer(-0.25, 0.25),
        #                                trainable=True)
        # self.u_omega = self.add_weight("u_omega", shape=[self.attention_size],
        #                                initializer=tf.random_uniform_initializer(-0.25, 0.25),
        #                                trainable=True)

        # Bi-LSTM
        # self.forward_cell = LSTMCell(self.hidden_size, forget_bias=1.0, state_is_tuple=True)
        # self.backward_cell = LSTMCell(self.hidden_size, forget_bias=1.0, state_is_tuple=True)
        # self.bi_lstm = Bidirectional(LSTM(self.hidden_size, return_sequences=True))
        # self.bi_lstm.build(input_shape=(None, self.character_num, self.character_embed))
        # self.trainable_weights += self.bi_lstm.trainable_weights
        #
        # self.self_attention = SeqSelfAttention(attention_activation='sigmoid')
        # self.self_attention.build(input_shape=(None, self.character_num, 2*self.hidden_size))
        # self.trainable_weights += self.self_attention.trainable_weights
        # print(self.trainable_weights)

        # DNN
        # self.w1 = self.add_weight('W1', [2*self.attention_size, self.cell_embed],
        #                           initializer=tf.random_uniform_initializer(-0.25, 0.25),
        #                           trainable=True)
        #
        # self.b1 = self.add_weight('b1', [self.cell_embed],
        #                           initializer=tf.zeros_initializer(),
        #                           trainable=True)

        # self.dense = Dense(self.cell_embed, activation="relu")
        # print(batch_input_shape[0], batch_input_shape[1], batch_input_shape[2])
        # self.dense.build(input_shape=(batch_input_shape[0]*batch_input_shape[1]*batch_input_shape[2],
        #                               2*self.attention_size))
        # self.trainable_weights += self.dense.trainable_weights

    def call(self, inputs, mask=None, **kwargs):
        def fn(x):
            # Per-cell pipeline (currently disabled): Bi-LSTM -> self-attention
            # -> attention pooling -> dense projection.
            print("fn_0", x)
            # (batch*height*width, character_num, hidden_size)
            # outputs, last_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=self.forward_cell,
            #                                                        cell_bw=self.backward_cell,
            #                                                        inputs=x,
            #                                                        dtype=tf.float32,
            #                                                        time_major=False)

            # (batch*height*width, character_num, 2*hidden_size)
            # outputs = self.bi_lstm(x)
            # print("fn_1", outputs)

            # (batch*height*width, character_num, 2*hidden_size)
            # outputs = self.self_attention(outputs)
            # print("fn_2", outputs)

            # (batch*height*width, 2*hidden_size)
            # outputs, _ = attention(outputs, self.w_omega, self.b_omega,
            #                        self.u_omega, time_major=False)

            # (batch*height*width, cell_embedding)
            # outputs = tf.nn.xw_plus_b(outputs, self.w1, self.b1)
            # outputs = self.dense(outputs)
            # print("fn_3", outputs)

            # With every step above disabled, pass the input through unchanged
            # (the original returned an undefined `outputs` here).
            return x

        batch = tf.shape(inputs)[0]
        height = tf.shape(inputs)[1]
        width = tf.shape(inputs)[2]

        # (batch, height*width, character_num(time_step), character_embedding)
        # inputs = tf.reshape(inputs, (tf.shape(inputs)[0],
        #                              height*width,
        #                              inputs.shape[3], inputs.shape[4]))

        # (batch*height*width, character_num, character_embedding)
        outputs = tf.reshape(inputs, (batch*height*width,
                                      inputs.shape[3], inputs.shape[4]))

        # (height*width, batch, character_num(time_step), character_embedding)
        # inputs = tf.transpose(inputs, (1, 0, 2, 3))

        # split height*width, each cell
        # (height*width, batch, cell_embedding)
        # outputs = tf.map_fn(fn=lambda x: fn(x), elems=inputs, dtype=tf.float32)
        # print("loop_lstm_1", outputs)
        # outputs = tf.squeeze(outputs, 0)

        # (batch*height*width, 2*attention_size)
        # outputs = fn(inputs)
        # print("loop_lstm_2", outputs)

        # (1, batch*height*width, 2*attention_size)
        # outputs = tf.expand_dims(outputs, 0)
        # print("loop_lstm_3", outputs)

        # (batch*height*width, cell_embedding)
        # outputs = Dense(self.cell_embed, activation="relu")(outputs)
        # print("loop_lstm_3", outputs)

        # (batch, height*width, cell_embedding)
        # outputs = tf.transpose(outputs, (1, 0, 2))
        # print("loop_lstm_2", outputs)

        # (batch, height, width, cell_embedding)
        # outputs = tf.reshape(outputs, (batch, height, width, self.cell_embed))
        # print("loop_lstm_4", outputs)

        return outputs

    def compute_output_shape(self, input_shape):
        return None, self.character_num, self.character_embed
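

# Hedged sketch (shapes are illustrative): with its Bi-LSTM/attention pipeline
# commented out, LoopCell currently just flattens the table grid so that each
# cell becomes one independent character sequence.
def _loop_cell_shape_demo():
    inp = Input(shape=(4, 3, 10, 60))  # (height, width, character_num, character_embed)
    out = LoopCell(hidden_size=64, attention_size=64, character_num=10,
                   character_embed=60, cell_embed=8)(inp)
    print(out)  # at run time: (batch*4*3, 10, 60)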


class BatchReshape(Layer):
    def __init__(self, cell_embed):
        super(BatchReshape, self).__init__()
        self.cell_embed = cell_embed

    def call(self, inputs, mask=None, **kwargs):
        input1 = inputs[0]
        input2 = inputs[1]

        batch = tf.shape(input1)[0]
        height = tf.shape(input1)[1]
        width = tf.shape(input1)[2]

        # (batch, height, width, cell_embedding)
        outputs = tf.reshape(input2, (batch, height, width, self.cell_embed))
        print("batch_reshape", outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        return None, None, None, self.cell_embed
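

# Hedged usage sketch for BatchReshape (shapes are illustrative): input1 is the
# original 5-D tensor, read only for its batch/height/width; input2 carries the
# flat (batch*height*width, cell_embed) cell embeddings to be re-gridded.
def _batch_reshape_demo():
    table = Input(shape=(4, 3, 10, 60))  # (height, width, chars, embed)
    flat = Input(batch_shape=(None, 8))  # flattened per-cell embeddings
    grid = BatchReshape(cell_embed=8)([table, flat])
    print(grid)  # at run time: (batch, 4, 3, 8)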

# def batch_reshape(x):
#     return K.reshape(x, (batch, height, width, cell_embed))


if __name__ == '__main__':
    # Spatial dims are None so the same model can fit the (16, 8) grids below
    # and the (32, 24) grids afterwards; the original hard-coded
    # (16, 8, 10, 60) here, which would reject the second training set.
    input_shape = (None, None, 10, 60)
    hidden_size = 64
    attention_size = 64
    character_num = 10
    character_embed = 60
    cell_embed = 8

    # (batch_size, row_num, col_num, character_num, character_embedding)
    X_train = np.random.uniform(0, 1, (10, 16, 8, 10, 60))
    X_test = np.random.uniform(0, 1, (10, 16, 8, 10, 60))
    y_train = np.random.uniform(0, 1, (10, 16, 8))
    y_test = np.random.uniform(0, 1, (10, 16, 8))

    _input = Input(shape=input_shape, dtype="float32")
    batch = K.shape(_input)[0]
    height = K.shape(_input)[1]
    width = K.shape(_input)[2]
    print(batch, height, width)

    loop_bi_lstm = LoopCell(hidden_size, attention_size,
                            character_num, character_embed,
                            cell_embed)(_input)
    print("model_2_1", loop_bi_lstm)

    # Collapse each cell's character sequence to a single vector before the
    # projection; with LoopCell's Bi-LSTM/attention pipeline disabled, this
    # Flatten+Dense stands in for the pooled cell embedding.
    flat = Flatten()(loop_bi_lstm)
    dense = Dense(cell_embed, activation="relu")(flat)
    print("model_2_2", dense)

    # Rebuild the (batch, height, width, cell_embed) grid; the original passed
    # the undefined `batch_reshape` to a Lambda here, BatchReshape does the job.
    reshape = BatchReshape(cell_embed)([_input, dense])
    print("model_2_3", reshape)

    # Run the U-Net over the reconstructed grid (the original fed the still
    # flattened loop_bi_lstm tensor, whose shape cannot match y below).
    u_net = u_net_small(reshape)
    merge = Concatenate(axis=-1)([reshape, u_net])
    dense = Dense(cell_embed, activation='relu')(merge)  # original: Dense(LoopCell().cell_embed, ...), which lacks required args
    dense = Dense(1, activation='sigmoid')(dense)
    squeeze = Lambda(lambda x: K.squeeze(x, axis=-1))(dense)

    model = models.Model(inputs=_input, outputs=squeeze)
    model.summary(line_length=120)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    model.fit(X_train, y_train,
              epochs=2,
              batch_size=1,
              validation_data=(X_test, y_test))
    # The same model also fits tables with a different grid size (32x24
    # instead of 16x8), since the spatial input dims are None.
    # (batch_size, row_num, col_num, character_num, character_embedding)
    X_train = np.random.uniform(0, 1, (5, 32, 24, 10, 60))
    X_test = np.random.uniform(0, 1, (5, 32, 24, 10, 60))
    y_train = np.random.uniform(0, 1, (5, 32, 24))
    y_test = np.random.uniform(0, 1, (5, 32, 24))

    model.fit(X_train, y_train,
              epochs=2,
              batch_size=1,
              validation_data=(X_test, y_test))
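
    # Hedged follow-up (illustrative, not in the original): inference should
    # likewise accept any grid whose height/width u_net_small can downsample
    # and upsample cleanly (an assumption about its pooling strides).
    pred = model.predict(np.random.uniform(0, 1, (2, 16, 8, 10, 60)), batch_size=1)
    print(pred.shape)  # expected: (2, 16, 8)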