inference.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. """Module for inference algorithms.
  2. This module contains different functions to perform inference on factor graphs.
  3. Functions:
  4. belief_propagation: Belief propagation
  5. sum_product: Sum-product algorithm
  6. max_product: Max-product algorithm
  7. max_sum: Max-sum algorithm
  8. loopy_belief_propagation: Loopy belief propagation
  9. mean_field: Mean-field algorithm
  10. """
  11. from random import choice
  12. import networkx as nx
  13. from . import nodes
  14. def belief_propagation(graph, query_node=None):
  15. """Belief propagation.
  16. Perform exact inference on tree structured graphs.
  17. Return the belief of all query_nodes.
  18. """
  19. if query_node is None: # pick random node
  20. query_node = choice(graph.get_vnodes())
  21. # Depth First Search to determine edges
  22. dfs = nx.dfs_edges(graph, query_node)
  23. # Convert tuple to reversed list
  24. backward_path = list(dfs)
  25. forward_path = reversed(backward_path)
  26. # Messages in forward phase
  27. for (v, u) in forward_path: # Edge direction: u -> v
  28. msg = u.spa(v)
  29. graph[u][v]['object'].set_message(u, v, msg)
  30. # Messages in backward phase
  31. for (u, v) in backward_path: # Edge direction: u -> v
  32. msg = u.spa(v)
  33. graph[u][v]['object'].set_message(u, v, msg)
  34. # Return marginal distribution
  35. return query_node.belief()
  36. def sum_product(graph, query_node=None):
  37. """Sum-product algorithm.
  38. Compute marginal distribution on graphs that are tree structured.
  39. Return the belief of all query_nodes.
  40. """
  41. # Sum-Product algorithm is equivalent to Belief Propagation
  42. return belief_propagation(graph, query_node)
  43. def max_product(graph, query_node=None):
  44. """Max-product algorithm.
  45. Compute setting of variables with maximum probability on graphs
  46. that are tree structured.
  47. Return the setting of all query_nodes.
  48. """
  49. track = {} # Setting of variables
  50. if query_node is None: # pick random node
  51. query_node = choice(graph.get_vnodes())
  52. # Depth First Search to determine edges
  53. dfs = nx.dfs_edges(graph, query_node)
  54. # Convert tuple to reversed list
  55. backward_path = list(dfs)
  56. forward_path = reversed(backward_path)
  57. # Messages in forward phase
  58. for (v, u) in forward_path: # Edge direction: u -> v
  59. msg = u.mpa(v)
  60. graph[u][v]['object'].set_message(u, v, msg)
  61. # Messages in backward phase
  62. for (u, v) in backward_path: # Edge direction: u -> v
  63. msg = u.mpa(v)
  64. graph[u][v]['object'].set_message(u, v, msg)
  65. # Maximum argument for query node
  66. track[query_node] = query_node.argmax()
  67. # Back-tracking
  68. for (u, v) in backward_path: # Edge direction: u -> v
  69. if v.type == nodes.NodeType.factor_node:
  70. for k in v.record[u].keys(): # Iterate over outgoing edges
  71. track[k] = v.record[u][k]
  72. # Return maximum probability for query node and setting of variable
  73. return query_node.maximum(), track
  74. def max_sum(graph, query_node=None):
  75. """Max-sum algorithm.
  76. Compute setting of variable for maximum probability on graphs
  77. that are tree structured.
  78. Return the setting of all query_nodes.
  79. """
  80. track = {} # Setting of variables
  81. if query_node is None: # pick random node
  82. query_node = choice(graph.get_vnodes())
  83. # Depth First Search to determine edges
  84. dfs = nx.dfs_edges(graph, query_node)
  85. # Convert tuple to reversed list
  86. backward_path = list(dfs)
  87. forward_path = reversed(backward_path)
  88. # Messages in forward phase
  89. for (v, u) in forward_path: # Edge direction: u -> v
  90. msg = u.msa(v)
  91. graph[u][v]['object'].set_message(u, v, msg)
  92. # Messages in backward phase
  93. for (u, v) in backward_path: # Edge direction: u -> v
  94. msg = u.msa(v)
  95. graph[u][v]['object'].set_message(u, v, msg)
  96. # Maximum argument for query node
  97. track[query_node] = query_node.argmax()
  98. # Back-tracking
  99. for (u, v) in backward_path: # Edge direction: u -> v
  100. if v.type == nodes.NodeType.factor_node:
  101. for k in v.record[u].keys(): # Iterate over outgoing edges
  102. track[k] = v.record[u][k]
  103. # Return maximum probability for query node and setting of variable
  104. return query_node.maximum(), track
  105. def loopy_belief_propagation(model, iterations, query_node=(), order=None):
  106. """Loopy belief propagation.
  107. Perform approximative inference on arbitrary structured graphs.
  108. Return the belief of all query_nodes.
  109. """
  110. if order is None:
  111. fn = [n for (n, attr) in model.nodes(data=True)
  112. if attr["type"] == "fn"]
  113. vn = [n for (n, attr) in model.nodes(data=True)
  114. if attr["type"] == "vn"]
  115. order = fn + vn
  116. return _schedule(model, 'spa', iterations, query_node, order)
  117. def mean_field(model, iterations, query_node=(), order=None):
  118. """Mean-field algorithm.
  119. Perform approximative inference on arbitrary structured graphs.
  120. Return the belief of all query_nodes.
  121. """
  122. if order is None:
  123. fn = [n for (n, attr) in model.nodes(data=True)
  124. if attr["type"] == "fn"]
  125. vn = [n for (n, attr) in model.nodes(data=True)
  126. if attr["type"] == "vn"]
  127. order = fn + vn
  128. return _schedule(model, 'mf', iterations, query_node, order)
  129. def _schedule(model, method, iterations, query_node, order):
  130. """Flooding schedule.
  131. A flooding scheduler for factor graphs with cycles.
  132. A given number of iterations is performed in a defined node order.
  133. Return the belief of all query_nodes.
  134. """
  135. b = {n: [] for n in query_node}
  136. # Iterative message passing
  137. for _ in range(iterations):
  138. # Visit nodes in predefined order
  139. for n in order:
  140. for neighbor in nx.all_neighbors(model, n):
  141. msg = getattr(n, method)(model, neighbor)
  142. model[n][neighbor]['object'].set_message(n, neighbor, msg)
  143. # Beliefs of query nodes
  144. for n in query_node:
  145. b[n].append(n.belief(model))
  146. return b