#!/usr/bin/env python

r"""A simple example of the sum-product algorithm

This is a simple example of the sum-product algorithm on a factor graph
with Bernoulli random variables, which is taken from page 409 of the book
C. M. Bishop, Pattern Recognition and Machine Learning. Springer, 2006.

      /--\      +----+      /--\      +----+      /--\
     | x1 |-----| fa |-----| x2 |-----| fb |-----| x3 |
      \--/      +----+      \--/      +----+      \--/
                             |
                           +----+
                           | fc |
                           +----+
                             |
                            /--\
                           | x4 |
                            \--/

The following joint distributions are used for the factor nodes.

    fa   | x2=0 x2=1     fb   | x3=0 x3=1     fc   | x4=0 x4=1
    ----------------     ----------------     ----------------
    x1=0 | 0.3  0.4      x2=0 | 0.3  0.4      x2=0 | 0.3  0.4
    x1=1 | 0.3  0.0      x2=1 | 0.3  0.0      x2=1 | 0.3  0.0

"""
# NOTE: the docstring above is a raw string on purpose. The node drawings
# contain backslashes ("\--/", "/--\"), which in a normal string literal are
# invalid escape sequences or, at end of line, swallow the newline.

from fglib import graphs, nodes, inference, rv

# Create factor graph
fg = graphs.FactorGraph()

# Create variable nodes
# (the rv.Discrete class itself is passed here, not an instance)
x1 = nodes.VNode("x1", rv.Discrete)
x2 = nodes.VNode("x2", rv.Discrete)
x3 = nodes.VNode("x3", rv.Discrete)
x4 = nodes.VNode("x4", rv.Discrete)

# Create factor nodes (with joint distributions)
# Each table matches the docstring: rows follow the first variable passed to
# rv.Discrete, columns follow the second one.
dist_fa = [[0.3, 0.4],
           [0.3, 0.0]]
fa = nodes.FNode("fa", rv.Discrete(dist_fa, x1, x2))

dist_fb = [[0.3, 0.4],
           [0.3, 0.0]]
fb = nodes.FNode("fb", rv.Discrete(dist_fb, x2, x3))

dist_fc = [[0.3, 0.4],
           [0.3, 0.0]]
fc = nodes.FNode("fc", rv.Discrete(dist_fc, x2, x4))

# Add nodes to factor graph
fg.set_nodes([x1, x2, x3, x4])
fg.set_nodes([fa, fb, fc])

# Add edges to factor graph
# (chain x1 - fa - x2 - fb - x3, with the branch x2 - fc - x4)
fg.set_edge(x1, fa)
fg.set_edge(fa, x2)

fg.set_edge(x2, fb)
fg.set_edge(fb, x3)

fg.set_edge(x2, fc)
fg.set_edge(fc, x4)

# Perform sum-product algorithm on factor graph
# and request belief of variable node x4
belief = inference.sum_product(fg, x4)

# Print belief of variables
print("Belief of variable node x4:")
print(belief)

# The remaining beliefs are read from the variable nodes after the run above.
# NOTE(review): presumably sum_product leaves the messages needed by
# VNode.belief() on every node -- confirm against the fglib documentation.
print("Belief of variable node x3:")
print(x3.belief())

print("Belief of variable node x2:")
print(x2.belief(normalize=True))

# normalize=False asks for the belief without normalization, as the printed
# label says.
print("Unnormalized belief of variable node x1:")
print(x1.belief(normalize=False))