# layers.py
from inits import *
import tensorflow as tf

# TF1-style flag plumbing: FLAGS exposes command-line hyperparameters globally.
flags = tf.app.flags
FLAGS = flags.FLAGS
  5. # global unique layer ID dictionary for layer name assignment
  6. _LAYER_UIDS = {}
  7. def get_layer_uid(layer_name=''):
  8. """Helper function, assigns unique layer IDs."""
  9. if layer_name not in _LAYER_UIDS:
  10. _LAYER_UIDS[layer_name] = 1
  11. return 1
  12. else:
  13. _LAYER_UIDS[layer_name] += 1
  14. return _LAYER_UIDS[layer_name]
  15. def sparse_dropout(x, keep_prob, noise_shape):
  16. """Dropout for sparse tensors."""
  17. random_tensor = keep_prob
  18. random_tensor += tf.random_uniform(noise_shape)
  19. dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
  20. pre_out = tf.sparse_retain(x, dropout_mask)
  21. return pre_out * (1./keep_prob)
  22. def dot(x, y, sparse=False):
  23. """Wrapper for tf.matmul (sparse vs dense)."""
  24. if sparse:
  25. res = tf.sparse_tensor_dense_matmul(x, y)
  26. else:
  27. res = tf.matmul(x, y)
  28. return res
  29. class Layer(object):
  30. """Base layer class. Defines basic API for all layer objects.
  31. Implementation inspired by keras (http://keras.io).
  32. # Properties
  33. name: String, defines the variable scope of the layer.
  34. logging: Boolean, switches Tensorflow histogram logging on/off
  35. # Methods
  36. _call(inputs): Defines computation graph of layer
  37. (i.e. takes input, returns output)
  38. __call__(inputs): Wrapper for _call()
  39. _log_vars(): Log all variables
  40. """
  41. def __init__(self, **kwargs):
  42. allowed_kwargs = {'name', 'logging'}
  43. for kwarg in kwargs.keys():
  44. assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
  45. name = kwargs.get('name')
  46. if not name:
  47. layer = self.__class__.__name__.lower()
  48. name = layer + '_' + str(get_layer_uid(layer))
  49. self.name = name
  50. self.vars = {}
  51. logging = kwargs.get('logging', False)
  52. self.logging = logging
  53. self.sparse_inputs = False
  54. def _call(self, inputs):
  55. return inputs
  56. def __call__(self, inputs):
  57. with tf.name_scope(self.name):
  58. if self.logging and not self.sparse_inputs:
  59. tf.summary.histogram(self.name + '/inputs', inputs)
  60. outputs = self._call(inputs)
  61. if self.logging:
  62. tf.summary.histogram(self.name + '/outputs', outputs)
  63. return outputs
  64. def _log_vars(self):
  65. for var in self.vars:
  66. tf.summary.histogram(self.name + '/vars/' + var, self.vars[var])
  67. class Dense(Layer):
  68. """Dense layer."""
  69. def __init__(self, input_dim, output_dim, placeholders, dropout=0., sparse_inputs=False,
  70. act=tf.nn.relu, bias=False, featureless=False, **kwargs):
  71. super(Dense, self).__init__(**kwargs)
  72. if dropout:
  73. self.dropout = placeholders['dropout']
  74. else:
  75. self.dropout = 0.
  76. self.act = act
  77. self.sparse_inputs = sparse_inputs
  78. self.featureless = featureless
  79. self.bias = bias
  80. # helper variable for sparse dropout
  81. self.num_features_nonzero = placeholders['num_features_nonzero']
  82. with tf.variable_scope(self.name + '_vars'):
  83. self.vars['weights'] = glorot([input_dim, output_dim],
  84. name='weights')
  85. if self.bias:
  86. self.vars['bias'] = zeros([output_dim], name='bias')
  87. if self.logging:
  88. self._log_vars()
  89. def _call(self, inputs):
  90. x = inputs
  91. # dropout
  92. if self.sparse_inputs:
  93. x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero)
  94. else:
  95. x = tf.nn.dropout(x, 1-self.dropout)
  96. # transform
  97. output = dot(x, self.vars['weights'], sparse=self.sparse_inputs)
  98. # bias
  99. if self.bias:
  100. output += self.vars['bias']
  101. return self.act(output)
class GraphConvolution(Layer):
    """Graph convolution layer. (featureless=True and transform=False) is not supported for now.

    Computes act(sum_i support[i] @ (x W_i) + b) over the list of support
    (adjacency-like) matrices in placeholders['support'].
    """

    def __init__(self, input_dim, output_dim, placeholders, dropout=0.,
                 sparse_inputs=False, act=tf.nn.relu, bias=False,
                 featureless=False, transform=True, init=glorot, **kwargs):
        super(GraphConvolution, self).__init__(**kwargs)

        # When dropout is requested, read the rate from the graph-level
        # placeholder; otherwise keep a constant 0.
        if dropout:
            self.dropout = placeholders['dropout']
        else:
            self.dropout = 0.

        self.act = act
        self.support = placeholders['support']
        self.sparse_inputs = sparse_inputs
        self.featureless = featureless
        self.bias = bias
        self.transform = transform

        # helper variable for sparse dropout
        self.num_features_nonzero = placeholders['num_features_nonzero']

        with tf.variable_scope(self.name + '_vars'):
            for i in range(len(self.support)):
                # Skip creating a weight matrix when the layer can act as a pure
                # propagation step (square, no transform, has features); _call
                # then falls back to pre_sup = x for that support.
                if input_dim == output_dim and not self.transform and not featureless:
                    continue
                self.vars['weights_' + str(i)] = init([input_dim, output_dim],
                                                      name='weights_' + str(i))
            if self.bias:
                self.vars['bias'] = zeros([output_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        x = inputs

        # dropout
        # NOTE(review): when dropout is enabled, self.dropout is a tf placeholder;
        # truth-testing a Tensor raises TypeError in TF1 graph mode — confirm the
        # dropout-enabled path is actually exercised.
        if self.dropout:
            if self.sparse_inputs:
                x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero)
            else:
                x = tf.nn.dropout(x, 1-self.dropout)

        # convolve
        supports = list()
        for i in range(len(self.support)):
            if 'weights_'+str(i) in self.vars:
                if not self.featureless:
                    # Standard path: transform features, then propagate.
                    pre_sup = dot(x, self.vars['weights_' + str(i)], sparse=self.sparse_inputs)
                else:
                    # Featureless: inputs are implicitly the identity, so the
                    # weight matrix itself is the transformed signal.
                    pre_sup = self.vars['weights_' + str(i)]
            else:
                # No weights were created for this support (see __init__):
                # propagate the input unchanged.
                pre_sup = x
            # support matrices are sparse; always use the sparse matmul path.
            support = dot(self.support[i], pre_sup, sparse=True)
            supports.append(support)
        output = tf.add_n(supports)

        # bias
        if self.bias:
            output += self.vars['bias']

        return self.act(output)