#!/usr/bin/env python

r"""A simple example of the max-product algorithm.

This is a simple example of the max-product algorithm on a factor graph
with Discrete random variables.

     /--\      +----+      /--\      +----+      /--\
    | x1 |-----| fa |-----| x2 |-----| fb |-----| x3 |
     \--/      +----+      \--/      +----+      \--/
                            |
                          +----+
                          | fc |
                          +----+
                            |
                           /--\
                          | x4 |
                           \--/

The following joint distributions are used for the factor nodes.

    fa   | x2=0 x2=1 x2=2     fb   | x3=0 x3=1     fc   | x4=0 x4=1
    ---------------------     ----------------     ----------------
    x1=0 | 0.3  0.2  0.1      x2=0 | 0.3  0.2      x2=0 | 0.3  0.2
    x1=1 | 0.3  0.0  0.1      x2=1 | 0.3  0.0      x2=1 | 0.3  0.0
                              x2=2 | 0.1  0.1      x2=2 | 0.1  0.1

"""
# NOTE: the docstring above is a raw string on purpose.  The ASCII diagram
# contains backslashes; in a normal string "\-" is an invalid escape sequence
# and a backslash at the end of a line would join it with the next line.

from fglib import graphs, nodes, inference, rv

# Create factor graph
fg = graphs.FactorGraph()

# Create variable nodes.
# x1 is a discrete random variable with 2 states (Bernoulli random variable),
# x2 is a discrete random variable with 3 states.
x1 = nodes.VNode("x1", rv.Discrete)
x2 = nodes.VNode("x2", rv.Discrete)
x3 = nodes.VNode("x3", rv.Discrete)
x4 = nodes.VNode("x4", rv.Discrete)

# Create factor nodes (with joint distributions).
# Rows follow the first variable passed to rv.Discrete and columns the
# second one, matching the tables in the module docstring.
dist_fa = [[0.3, 0.2, 0.1],
           [0.3, 0.0, 0.1]]
fa = nodes.FNode("fa", rv.Discrete(dist_fa, x1, x2))

dist_fb = [[0.3, 0.2],
           [0.3, 0.0],
           [0.1, 0.1]]
fb = nodes.FNode("fb", rv.Discrete(dist_fb, x2, x3))

dist_fc = [[0.3, 0.2],
           [0.3, 0.0],
           [0.1, 0.1]]
fc = nodes.FNode("fc", rv.Discrete(dist_fc, x2, x4))

# Add nodes to factor graph
fg.set_nodes([x1, x2, x3, x4])
fg.set_nodes([fa, fb, fc])

# Add edges to factor graph (same topology as the diagram above)
fg.set_edge(x1, fa)
fg.set_edge(fa, x2)
fg.set_edge(x2, fb)
fg.set_edge(fb, x3)
fg.set_edge(x2, fc)
fg.set_edge(fc, x4)

# Perform max-product algorithm on factor graph
# and request maximum of variable node x4.
# Only the first element of the returned pair is used here.
maximum, _ = inference.max_product(fg, x4)

# Print maximum of variables; the other nodes are queried directly
# through their own maximum() method.
print("Maximum of variable node x4:")
print(maximum)
print("Maximum of variable node x3:")
print(x3.maximum())
print("Maximum of variable node x2:")
print(x2.maximum(normalize=True))
print("Unnormalized Maximum of variable node x1:")
print(x1.maximum(normalize=False))