example_spa 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #!/usr/bin/env python
  2. """A simple example of the sum-product algorithm
  3. This is a simple example of the sum-product algorithm on a factor graph
  4. with Bernoulli random variables, which is taken from page 409 of the book
  5. C. M. Bishop, Pattern Recognition and Machine Learning. Springer, 2006.
  6. /--\ +----+ /--\ +----+ /--\
  7. | x1 |-----| fa |-----| x2 |-----| fb |-----| x3 |
  8. \--/ +----+ \--/ +----+ \--/
  9. |
  10. +----+
  11. | fc |
  12. +----+
  13. |
  14. /--\
  15. | x4 |
  16. \--/
  17. The following joint distributions are used for the factor nodes.
  18. fa | x2=0 x2=1 fb | x3=0 x3=1 fc | x4=0 x4=1
  19. ---------------- ---------------- ----------------
  20. x1=0 | 0.3 0.4 x2=0 | 0.3 0.4 x2=0 | 0.3 0.4
  21. x1=1 | 0.3 0.0 x2=1 | 0.3 0.0 x2=1 | 0.3 0.0
  22. """
  23. from fglib import graphs, nodes, inference, rv
  24. # Create factor graph
  25. fg = graphs.FactorGraph()
  26. # Create variable nodes
  27. x1 = nodes.VNode("x1", rv.Discrete)
  28. x2 = nodes.VNode("x2", rv.Discrete)
  29. x3 = nodes.VNode("x3", rv.Discrete)
  30. x4 = nodes.VNode("x4", rv.Discrete)
  31. # Create factor nodes (with joint distributions)
  32. dist_fa = [[0.3, 0.4],
  33. [0.3, 0.0]]
  34. fa = nodes.FNode("fa", rv.Discrete(dist_fa, x1, x2))
  35. dist_fb = [[0.3, 0.4],
  36. [0.3, 0.0]]
  37. fb = nodes.FNode("fb", rv.Discrete(dist_fb, x2, x3))
  38. dist_fc = [[0.3, 0.4],
  39. [0.3, 0.0]]
  40. fc = nodes.FNode("fc", rv.Discrete(dist_fc, x2, x4))
  41. # Add nodes to factor graph
  42. fg.set_nodes([x1, x2, x3, x4])
  43. fg.set_nodes([fa, fb, fc])
  44. # Add edges to factor graph
  45. fg.set_edge(x1, fa)
  46. fg.set_edge(fa, x2)
  47. fg.set_edge(x2, fb)
  48. fg.set_edge(fb, x3)
  49. fg.set_edge(x2, fc)
  50. fg.set_edge(fc, x4)
  51. # Perform sum-product algorithm on factor graph
  52. # and request belief of variable node x4
  53. belief = inference.sum_product(fg, x4)
  54. # Print belief of variables
  55. print("Belief of variable node x4:")
  56. print(belief)
  57. print("Belief of variable node x3:")
  58. print(x3.belief())
  59. print("Belief of variable node x2:")
  60. print(x2.belief(normalize=True))
  61. print("Unnormalized belief of variable node x1:")
  62. print(x1.belief(normalize=False))