example_max 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #!/usr/bin/env python
  2. """A simple example of the max-product algorithm
  3. This is a simple example of the max-product algorithm on a factor graph
  4. with Discrete random variables.
  5. /--\ +----+ /--\ +----+ /--\
  6. | x1 |-----| fa |-----| x2 |-----| fb |-----| x3 |
  7. \--/ +----+ \--/ +----+ \--/
  8. |
  9. +----+
  10. | fc |
  11. +----+
  12. |
  13. /--\
  14. | x4 |
  15. \--/
  16. The following joint distributions are used for the factor nodes.
  17. fa | x2=0 x2=1 x2=2 fb | x3=0 x3=1 fc | x4=0 x4=1
  18. --------------------- ---------------- ----------------
  19. x1=0 | 0.3 0.2 0.1 x2=0 | 0.3 0.2 x2=0 | 0.3 0.2
  20. x1=1 | 0.3 0.0 0.1 x2=1 | 0.3 0.0 x2=1 | 0.3 0.0
  21. x2=2 | 0.1 0.1 x2=2 | 0.1 0.1
  22. """
  23. from fglib import graphs, nodes, inference, rv
  24. # Create factor graph
  25. fg = graphs.FactorGraph()
  26. # Create variable nodes
  27. x1 = nodes.VNode("x1", rv.Discrete) # Discrete random variable with 2 states (Bernoulli random variable)
  28. x2 = nodes.VNode("x2", rv.Discrete) # Discrete random variable with 3 states
  29. x3 = nodes.VNode("x3", rv.Discrete)
  30. x4 = nodes.VNode("x4", rv.Discrete)
  31. # Create factor nodes (with joint distributions)
  32. dist_fa = [[0.3, 0.2, 0.1],
  33. [0.3, 0.0, 0.1]]
  34. fa = nodes.FNode("fa", rv.Discrete(dist_fa, x1, x2))
  35. dist_fb = [[0.3, 0.2],
  36. [0.3, 0.0],
  37. [0.1, 0.1]]
  38. fb = nodes.FNode("fb", rv.Discrete(dist_fb, x2, x3))
  39. dist_fc = [[0.3, 0.2],
  40. [0.3, 0.0],
  41. [0.1, 0.1]]
  42. fc = nodes.FNode("fc", rv.Discrete(dist_fc, x2, x4))
  43. # Add nodes to factor graph
  44. fg.set_nodes([x1, x2, x3, x4])
  45. fg.set_nodes([fa, fb, fc])
  46. # Add edges to factor graph
  47. fg.set_edge(x1, fa)
  48. fg.set_edge(fa, x2)
  49. fg.set_edge(x2, fb)
  50. fg.set_edge(fb, x3)
  51. fg.set_edge(x2, fc)
  52. fg.set_edge(fc, x4)
  53. # Perform max-product algorithm on factor graph
  54. # and request maximum of variable node x4
  55. maximum, _ = inference.max_product(fg, x4)
  56. # Print maximum of variables
  57. print("Maximum of variable node x4:")
  58. print(maximum)
  59. print("Maximum of variable node x3:")
  60. print(x3.maximum())
  61. print("Maximum of variable node x2:")
  62. print(x2.maximum(normalize=True))
  63. print("Unnormalized Maximum of variable node x1:")
  64. print(x1.maximum(normalize=False))