# models.py
# NOTE(review): wildcard imports pull in tf, the layer classes (Dense,
# GraphConvolution), the masked metrics / alignment losses, and the weight
# initializers used below — narrowing them requires the sibling modules.
from layers import *
from metrics import *
from inits import *

# TF1-style command-line flags; hyperparameters (learning_rate, hidden1,
# weight_decay, gamma, k) are read from FLAGS throughout this file.
flags = tf.app.flags
FLAGS = flags.FLAGS
  6. class Model(object):
  7. def __init__(self, **kwargs):
  8. allowed_kwargs = {'name', 'logging'}
  9. for kwarg in kwargs.keys():
  10. assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
  11. name = kwargs.get('name')
  12. if not name:
  13. name = self.__class__.__name__.lower()
  14. self.name = name
  15. logging = kwargs.get('logging', False)
  16. self.logging = logging
  17. self.vars = {}
  18. self.placeholders = {}
  19. self.layers = []
  20. self.activations = []
  21. self.inputs = None
  22. self.outputs = None
  23. self.loss = 0
  24. self.accuracy = 0
  25. self.optimizer = None
  26. self.opt_op = None
  27. def _build(self):
  28. raise NotImplementedError
  29. def build(self):
  30. """ Wrapper for _build() """
  31. with tf.variable_scope(self.name):
  32. self._build()
  33. # Build sequential layer model
  34. self.activations.append(self.inputs)
  35. for layer in self.layers:
  36. hidden = layer(self.activations[-1])
  37. self.activations.append(hidden)
  38. self.outputs = self.activations[-1]
  39. # Store model variables for easy access
  40. variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
  41. self.vars = {var.name: var for var in variables}
  42. # Build metrics
  43. self._loss()
  44. self._accuracy()
  45. self.opt_op = self.optimizer.minimize(self.loss)
  46. def predict(self):
  47. pass
  48. def _loss(self):
  49. raise NotImplementedError
  50. def _accuracy(self):
  51. raise NotImplementedError
  52. def save(self, sess=None):
  53. if not sess:
  54. raise AttributeError("TensorFlow session not provided.")
  55. saver = tf.train.Saver(self.vars)
  56. save_path = saver.save(sess, "tmp/%s.ckpt" % self.name)
  57. print("Model saved in file: %s" % save_path)
  58. def load(self, sess=None):
  59. if not sess:
  60. raise AttributeError("TensorFlow session not provided.")
  61. saver = tf.train.Saver(self.vars)
  62. save_path = "tmp/%s.ckpt" % self.name
  63. saver.restore(sess, save_path)
  64. print("Model restored from file: %s" % save_path)
  65. class MLP(Model):
  66. def __init__(self, placeholders, input_dim, **kwargs):
  67. super(MLP, self).__init__(**kwargs)
  68. self.inputs = placeholders['features']
  69. self.input_dim = input_dim
  70. # self.input_dim = self.inputs.get_shape().as_list()[1] # To be supported in future Tensorflow versions
  71. self.output_dim = placeholders['labels'].get_shape().as_list()[1]
  72. self.placeholders = placeholders
  73. self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
  74. self.build()
  75. def _loss(self):
  76. # Weight decay loss
  77. for var in self.layers[0].vars.values():
  78. self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
  79. # Cross entropy error
  80. self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'],
  81. self.placeholders['labels_mask'])
  82. def _accuracy(self):
  83. self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'],
  84. self.placeholders['labels_mask'])
  85. def _build(self):
  86. self.layers.append(Dense(input_dim=self.input_dim,
  87. output_dim=FLAGS.hidden1,
  88. placeholders=self.placeholders,
  89. act=tf.nn.relu,
  90. dropout=True,
  91. sparse_inputs=True,
  92. logging=self.logging))
  93. self.layers.append(Dense(input_dim=FLAGS.hidden1,
  94. output_dim=self.output_dim,
  95. placeholders=self.placeholders,
  96. act=lambda x: x,
  97. dropout=True,
  98. logging=self.logging))
  99. def predict(self):
  100. return tf.nn.softmax(self.outputs)
  101. class GCN(Model):
  102. def __init__(self, placeholders, input_dim, **kwargs):
  103. super(GCN, self).__init__(**kwargs)
  104. self.inputs = placeholders['features']
  105. self.input_dim = input_dim
  106. # self.input_dim = self.inputs.get_shape().as_list()[1] # To be supported in future Tensorflow versions
  107. self.output_dim = placeholders['labels'].get_shape().as_list()[1]
  108. self.placeholders = placeholders
  109. self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
  110. self.build()
  111. def _loss(self):
  112. # Weight decay loss
  113. for var in self.layers[0].vars.values():
  114. self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
  115. # Cross entropy error
  116. self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'],
  117. self.placeholders['labels_mask'])
  118. def _accuracy(self):
  119. self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'],
  120. self.placeholders['labels_mask'])
  121. def _build(self):
  122. self.layers.append(GraphConvolution(input_dim=self.input_dim,
  123. output_dim=FLAGS.hidden1,
  124. placeholders=self.placeholders,
  125. act=tf.nn.relu,
  126. dropout=True,
  127. sparse_inputs=True,
  128. logging=self.logging))
  129. self.layers.append(GraphConvolution(input_dim=FLAGS.hidden1,
  130. output_dim=self.output_dim,
  131. placeholders=self.placeholders,
  132. act=lambda x: x,
  133. dropout=True,
  134. logging=self.logging))
  135. def predict(self):
  136. return tf.nn.softmax(self.outputs)
  137. class GCN_Align(Model):
  138. def __init__(self, placeholders, input_dim, output_dim, ILL,
  139. sparse_inputs=False, featureless=True, AE=True, **kwargs):
  140. super(GCN_Align, self).__init__(**kwargs)
  141. self.inputs = placeholders['features']
  142. self.input_dim = input_dim
  143. self.output_dim = output_dim
  144. self.placeholders = placeholders
  145. self.ILL = ILL
  146. self.sparse_inputs = sparse_inputs
  147. self.featureless = featureless
  148. self.AE = AE
  149. self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
  150. self.build()
  151. def _loss(self):
  152. self.loss += align_loss(self.outputs, self.ILL, FLAGS.gamma, FLAGS.k, AE=self.AE)
  153. def _accuracy(self):
  154. pass
  155. def _build(self):
  156. self.layers.append(GraphConvolution(input_dim=self.input_dim,
  157. output_dim=self.output_dim,
  158. placeholders=self.placeholders,
  159. act=tf.nn.relu,
  160. dropout=False,
  161. featureless=self.featureless,
  162. sparse_inputs=self.sparse_inputs,
  163. transform=False,
  164. init=trunc_normal,
  165. logging=self.logging))
  166. self.layers.append(GraphConvolution(input_dim=self.output_dim,
  167. output_dim=self.output_dim,
  168. placeholders=self.placeholders,
  169. act=lambda x: x,
  170. dropout=False,
  171. transform=False,
  172. logging=self.logging))