edges.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. """Module for edges of factor graphs.
  2. This module contains classes for edges of factor graphs,
  3. which are used to build factor graphs.
  4. Classes:
  5. Edge: Class for edges.
  6. """
  7. from . import nodes
  8. class Edge:
  9. """Edge.
  10. Base class for all edges.
  11. Each edge class contains a message attribute,
  12. which stores the corresponding message in forward and backward direction.
  13. """
  14. def __init__(self, snode, tnode, init=None):
  15. """Create an edge."""
  16. # Array Index
  17. self.index = {snode: 0, tnode: 1}
  18. # Two-dimensional message list
  19. self.message = [[None, init],
  20. [init, None]]
  21. # Variable node
  22. if snode.type == nodes.NodeType.variable_node:
  23. self.variable = snode
  24. else:
  25. self.variable = tnode
  26. def __str__(self):
  27. """Return string representation."""
  28. return str(self.message)
  29. def set_message(self, snode, tnode, value):
  30. """Set value of message from source node to target node."""
  31. self.message[self.index[snode]][self.index[tnode]] = value
  32. def get_message(self, snode, tnode):
  33. """Return value of message from source node to target node."""
  34. return self.message[self.index[snode]][self.index[tnode]]