utils.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. """Module for utilities.
  2. This module contains auxiliary functions for factor graphs.
  3. Functions:
  4. draw: Draw a factor graph with nodes, edges and labels.
  5. draw_message: Draw messages of a factor graph.
  6. draw_attribute: Draw node attributes of a factor graph.
  7. """
  8. import matplotlib.pyplot as plt
  9. import networkx as nx
  10. def draw(graph, pos=None):
  11. """Draw factor graph and return used positions for nodes."""
  12. if pos is None:
  13. pos = nx.spring_layout(graph)
  14. # Draw variable nodes
  15. vn = [n for (n, d) in graph.nodes(data=True) if d['type'] == "vn"]
  16. vn_observed = [n for n in vn if n.observed]
  17. nx.draw_networkx_nodes(graph, pos, nodelist=vn_observed, node_size=1000,
  18. node_color="gray", node_shape='o')
  19. vn_hidden = [n for n in vn if not n.observed]
  20. nx.draw_networkx_nodes(graph, pos, nodelist=vn_hidden, node_size=1000,
  21. node_color="white", node_shape='o')
  22. # Draw factor nodes
  23. fn = [n for (n, d) in graph.nodes(data=True) if d['type'] == "fn"]
  24. nx.draw_networkx_nodes(graph, pos, nodelist=fn, node_size=1500,
  25. node_color="white", node_shape='s')
  26. # Draw labels
  27. nx.draw_networkx_labels(graph, pos, font_size=22, font_family='sans-serif')
  28. # Draw edges
  29. nx.draw_networkx_edges(graph, pos, alpha=0.5, edge_color='black')
  30. return pos
  31. def draw_message(graph, pos):
  32. """Draw messages of a factor graph."""
  33. msg = {} # Dict of node tuples to edge labels: {(nodeX, nodeY): aString}
  34. for u, v in graph.edges():
  35. m = graph.get_edge_data(u, v)["object"]
  36. s = "$m_{" + str(u).replace('$', '') + " -> " + \
  37. str(v).replace('$', '') + "}$ = " + str(m.get_message(u, v))
  38. s = s + "\n\n"
  39. s = s + "$m_{" + str(v).replace('$', '') + " -> " + \
  40. str(u).replace('$', '') + "}$ = " + str(m.get_message(v, u))
  41. msg[(u, v)] = s
  42. bbox_props = dict(boxstyle='round',
  43. alpha=0.5,
  44. ec="none",
  45. fc="none")
  46. nx.draw_networkx_edge_labels(graph, pos, font_size=12, edge_labels=msg,
  47. bbox=bbox_props) # draw the edge labels
  48. def draw_attribute(graph, pos, attr):
  49. """Draw node attributes of a factor graph."""
  50. labels = dict((n, d[attr]) for n, d in graph.nodes(data=True) if attr in d)
  51. for n, d in labels.items():
  52. x, y = pos[n]
  53. plt.text(x, y - 0.1, s="%s = %s" % (attr, d),
  54. bbox=dict(facecolor='red', alpha=0.5),
  55. horizontalalignment='center')