nodes.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. """Module for nodes of factor graphs.
  2. This module contains classes for nodes of factor graphs,
  3. which are used to build factor graphs.
  4. Classes:
  5. Node: Abstract class for nodes.
  6. VNode: Class for variable nodes.
  7. IOVNode: Class for custom input-output variable nodes.
  8. FNode: Class for factor nodes.
  9. IOFNode: Class for custom input-output factor nodes.
  10. """
  11. from abc import ABC, abstractmethod, abstractproperty
  12. from enum import Enum
  13. from types import MethodType
  14. import networkx as nx
  15. import numpy as np
  16. from . import rv
  17. class NodeType(Enum):
  18. """Enumeration for node types."""
  19. variable_node = 1
  20. factor_node = 2
  21. class Node(ABC):
  22. """Abstract base class for all nodes."""
  23. def __init__(self, label):
  24. """Create a node with an associated label."""
  25. self.__label = str(label)
  26. self.__graph = None
  27. def __str__(self):
  28. """Return string representation."""
  29. return self.__label
  30. @abstractproperty
  31. def type(self):
  32. pass
  33. @property
  34. def graph(self):
  35. return self.__graph
  36. @graph.setter
  37. def graph(self, graph):
  38. self.__graph = graph
  39. def neighbors(self, exclusion=None):
  40. """Get all neighbors with a given exclusion.
  41. Return iterator over all neighboring nodes
  42. without the given exclusion node.
  43. Positional arguments:
  44. exclusion -- the exclusion node
  45. """
  46. if exclusion is None:
  47. return nx.all_neighbors(self.graph, self)
  48. else:
  49. # Build iterator set
  50. iterator = (exclusion,) \
  51. if not isinstance(exclusion, list) else exclusion
  52. # Return neighbors excluding iterator set
  53. return (n for n in nx.all_neighbors(self.graph, self)
  54. if n not in iterator)
  55. @abstractmethod
  56. def spa(self, tnode):
  57. pass
  58. @abstractmethod
  59. def mpa(self, tnode):
  60. pass
  61. @abstractmethod
  62. def msa(self, tnode):
  63. pass
  64. @abstractmethod
  65. def mf(self, tnode):
  66. pass
  67. class VNode(Node):
  68. """Variable node.
  69. Variable node inherited from node base class.
  70. Extends the base class with message passing methods.
  71. """
  72. def __init__(self, label, rv_type, observed=False):
  73. """Create a variable node."""
  74. super().__init__(label)
  75. self.init = rv_type.unity(self)
  76. self.observed = observed
  77. @property
  78. def type(self):
  79. return NodeType.variable_node
  80. @property
  81. def init(self):
  82. return self.__init
  83. @init.setter
  84. def init(self, init):
  85. self.__init = init
  86. def belief(self, normalize=True):
  87. """Return belief of the variable node.
  88. Args:
  89. normalize: Boolean flag if belief should be normalized.
  90. """
  91. iterator = self.graph.neighbors(self)
  92. # Pick first node
  93. n = next(iterator)
  94. # Product over all incoming messages
  95. belief = self.graph[n][self]['object'].get_message(n, self)
  96. for n in iterator:
  97. belief *= self.graph[n][self]['object'].get_message(n, self)
  98. if normalize:
  99. belief = belief.normalize()
  100. return belief
  101. def maximum(self, normalize=True):
  102. """Return the maximum probability of the variable node.
  103. Args:
  104. normalize: Boolean flag if belief should be normalized.
  105. """
  106. b = self.belief(normalize)
  107. return np.amax(b.pmf)
  108. def argmax(self):
  109. """Return the argument for maximum probability of the variable node."""
  110. # In case of multiple occurrences of the maximum values,
  111. # the indices corresponding to the first occurrence are returned.
  112. b = self.belief()
  113. return b.argmax(self)
  114. def spa(self, tnode):
  115. """Return message of the sum-product algorithm."""
  116. if self.observed:
  117. return self.init
  118. else:
  119. # Initial message
  120. msg = self.init
  121. # Product over incoming messages
  122. for n in self.neighbors(tnode):
  123. msg *= self.graph[n][self]['object'].get_message(n, self)
  124. return msg
  125. def mpa(self, tnode):
  126. """Return message of the max-product algorithm."""
  127. return self.spa(tnode)
  128. def msa(self, tnode):
  129. """Return message of the max-sum algorithm."""
  130. if self.observed:
  131. return self.init.log()
  132. else:
  133. # Initial (logarithmized) message
  134. msg = self.init.log()
  135. # Sum over incoming messages
  136. for n in self.neighbors(tnode):
  137. msg += self.graph[n][self]['object'].get_message(n, self)
  138. return msg
  139. def mf(self, tnode):
  140. """Return message of the mean-field algorithm."""
  141. if self.observed:
  142. return self.init
  143. else:
  144. return self.belief(self.graph)
  145. class IOVNode(VNode):
  146. """Input-output variable node.
  147. Input-output variable node inherited from variable node class.
  148. Overwrites all message passing methods of the base class
  149. with a given callback function.
  150. """
  151. def __init__(self, label, init=None, observed=False, callback=None):
  152. """Create an input-output variable node."""
  153. super().__init__(label, init, observed)
  154. if callback is not None:
  155. self.set_callback(callback)
  156. def set_callback(self, callback):
  157. """Set callback function.
  158. Add bounded methods to the class instance in order to overwrite
  159. the existing message passing methods.
  160. """
  161. self.spa = MethodType(callback, self)
  162. self.mpa = MethodType(callback, self)
  163. self.msa = MethodType(callback, self)
  164. self.mf = MethodType(callback, self)
  165. class FNode(Node):
  166. """Factor node.
  167. Factor node inherited from node base class.
  168. Extends the base class with message passing methods.
  169. """
  170. def __init__(self, label, factor=None):
  171. """Create a factor node."""
  172. super().__init__(label)
  173. self.factor = factor
  174. self.record = {}
  175. @property
  176. def type(self):
  177. return NodeType.factor_node
  178. @property
  179. def factor(self):
  180. return self.__factor
  181. @factor.setter
  182. def factor(self, factor):
  183. self.__factor = factor
  184. def spa(self, tnode):
  185. """Return message of the sum-product algorithm."""
  186. # Initialize with local factor
  187. msg = self.factor
  188. # Product over incoming messages
  189. for n in self.neighbors(tnode):
  190. msg *= self.graph[n][self]['object'].get_message(n, self)
  191. # Integration/Summation over incoming variables
  192. for n in self.neighbors(tnode):
  193. msg = msg.marginalize(n, normalize=False)
  194. return msg
  195. def mpa(self, tnode):
  196. """Return message of the max-product algorithm."""
  197. self.record[tnode] = {}
  198. # Initialize with local factor
  199. msg = self.factor
  200. # Product over incoming messages
  201. for n in self.neighbors(tnode):
  202. msg *= self.graph[n][self]['object'].get_message(n, self)
  203. # Maximization over incoming variables
  204. for n in self.neighbors(tnode):
  205. self.record[tnode][n] = msg.argmax(n) # Record for back-tracking
  206. msg = msg.maximize(n, normalize=False)
  207. return msg
  208. def msa(self, tnode):
  209. """Return message of the max-sum algorithm."""
  210. self.record[tnode] = {}
  211. # Initialize with (logarithmized) local factor
  212. msg = self.factor.log()
  213. # Sum over incoming messages
  214. for n in self.neighbors(tnode):
  215. msg += self.graph[n][self]['object'].get_message(n, self)
  216. # Maximization over incoming variables
  217. for n in self.neighbors(tnode):
  218. self.record[tnode][n] = msg.argmax(n) # Record for back-tracking
  219. msg = msg.maximize(n, normalize=False)
  220. return msg
  221. def mf(self, tnode):
  222. """Return message of the mean-field algorithm."""
  223. # Initialize with local factor
  224. msg = self.factor
  225. # # Product over incoming messages
  226. # for n in self.neighbors(graph, self, tnode):
  227. # msg *= graph[n][self]['object'].get_message(n, self)
  228. #
  229. # # Integration/Summation over incoming variables
  230. # for n in self.neighbors(graph, self, tnode):
  231. # msg = msg.int(n)
  232. return msg
  233. class IOFNode(FNode):
  234. """Input-output factor node.
  235. Input-output factor node inherited from factor node class.
  236. Overwrites all message passing methods of the base class
  237. with a given callback function.
  238. """
  239. def __init__(self, label, factor, callback=None):
  240. """Create an input-output factor node."""
  241. super().__init__(self, label, factor)
  242. if callback is not None:
  243. self.set_callback(callback)
  244. def set_callback(self, callback):
  245. """Set callback function.
  246. Add bounded methods to the class instance in order to overwrite
  247. the existing message passing methods.
  248. """
  249. self.spa = MethodType(callback, self)
  250. self.mpa = MethodType(callback, self)
  251. self.msa = MethodType(callback, self)
  252. self.mf = MethodType(callback, self)