Transformer fundamentals
 
Loading...
Searching...
No Matches
graph_utils.py
Go to the documentation of this file.
1import os
2import tempfile
3import time
4
5import matplotlib.pyplot as plt
6import networkx as nx
7import numpy as np
8import pydot
9from IPython.display import Image, display
10
11
class Graph:
    """
    @brief Custom graph wrapper that can load, manipulate, and visualize
           graphs stored in .dot files (basic graphs only for now).

    The underlying NetworkX graph is kept in self.graph. It starts out as an
    empty nx.DiGraph; load_from_dot() replaces it with an nx.Graph when the
    DOT file declares an undirected `graph`, and with an nx.DiGraph otherwise.
    """

    def __init__(self):
        """
        @brief Initialize an empty directed graph

        @param self: the current graph object
        """
        self.graph = nx.DiGraph()

    def load_from_dot(self, file_path):
        """
        @brief Load a graph from a .dot file, replacing the current graph.

        Only the first graph in the file is used. Undirected DOT graphs become
        nx.Graph and everything else becomes nx.DiGraph; converting to these
        simple graph types collapses any parallel edges the file declares.

        @param file_path: Path to the .dot file

        @return bool: True if successful, False if it fails for any reason
                      (the error is printed and the current graph is kept)
        """
        try:
            dot_graphs = pydot.graph_from_dot_file(file_path)
            # pydot may return None or an empty list for an unparsable file.
            if not dot_graphs:
                print("No graphs was found in the .dot file (check the file maybe?)")
                return False

            dot_graph = dot_graphs[0]
            graph_cls = nx.Graph if dot_graph.get_type() == "graph" else nx.DiGraph
            # self.graph is only replaced once the conversion has succeeded.
            self.graph = graph_cls(nx.drawing.nx_pydot.from_pydot(dot_graph))
            print(
                f"Success: Graph with {self.graph.number_of_nodes()} nodes "
                f"and {self.graph.number_of_edges()} edges"
            )
            return True

        except Exception as e:
            print(f"Error loading .dot file: {e}")
            return False

    def add_node(self, node_id, **attributes):
        """
        @brief Add a node to the current graph

        @param self: the current graph object
        @param node_id: tag (identifier) for the node
        @param **attributes: optional node attributes
        """
        self.graph.add_node(node_id, **attributes)

    def add_edge(self, start, dest, **attributes):
        """
        @brief Add an edge to the graph (missing endpoints are created)

        @param self: the current graph object
        @param start: The node to start at
        @param dest: the node to end at
        @param **attributes: optional edge attributes
        """
        self.graph.add_edge(start, dest, **attributes)

    def save_to_dot(self, file_path):
        """
        @brief Save the current graph to a .dot file

        The DOT source generated from the NetworkX graph is written verbatim,
        so saving does not require the Graphviz executables and does not add
        layout attributes to the file.

        @param self: the current graph object
        @param file_path: output file path
        @return bool: True on success, False if something failed (error printed)
        """
        try:
            pydot_graph = nx.drawing.nx_pydot.to_pydot(self.graph)
            # write_raw dumps to_string() as-is; write_dot would run Graphviz's
            # `dot -Tdot`, which needs Graphviz installed and injects pos/bb.
            pydot_graph.write_raw(file_path)
            print(f"Graph saved to {file_path}")
            return True
        except Exception as e:
            print(f"Error saving graph to .dot file: {e}")
            return False

    def display_with_graphviz(
        self, layout="dot", format="png", show=True, output_file=None
    ):
        """
        @brief Render the graph with Graphviz

        @param layout: Graphviz layout engine ('dot', 'neato', 'fdp', 'sfdp', 'twopi', 'circo')
        @param format: Output format ('png', 'svg', 'pdf', etc.)
        @param show: Display the image (True) or return the rendered bytes (False)
        @param output_file: If given, write the rendering to this path and
                            return None (show is ignored)
        @return bytes or None: the rendered image data when show=False and no
                output_file is given; otherwise None (also None on error,
                which is printed)
        """
        try:
            pydot_graph = nx.drawing.nx_pydot.to_pydot(self.graph)
            pydot_graph.set_layout(layout)

            if output_file:
                pydot_graph.write(output_file, format=format)
                print(f"Graph visualization saved to {output_file}")
                return None

            # Render into a temporary file, which is removed afterwards unless
            # the user needs it to view the result.
            with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmp:
                temp_name = tmp.name

            keep_temp = False
            try:
                pydot_graph.write(temp_name, format=format)

                if not show:
                    with open(temp_name, "rb") as f:
                        return f.read()

                if self._try_notebook_display(temp_name):
                    return None

                # No inline display possible: leave the file for the user.
                keep_temp = True
                print(f"Graph saved to temporary file: {temp_name}")
                print(
                    "(Note: Interactive display not available in terminal environment)"
                )
                return None
            finally:
                if not keep_temp:
                    try:
                        os.unlink(temp_name)
                    except OSError:
                        pass

        except Exception as e:
            print(f"Error displaying graph: {e}")
            return None

    @staticmethod
    def _try_notebook_display(image_path):
        """
        @brief Show an image file inline when running under IPython.

        @param image_path: path of the rendered image
        @return bool: True if the image was handed to IPython's display(),
                      False if no IPython shell is active or display failed
        """
        from IPython import get_ipython

        if get_ipython() is None:
            # Plain Python: display() would merely print the object's repr.
            return False
        try:
            # Image(filename=...) reads the file eagerly, so the caller may
            # delete it after this returns.
            display(Image(filename=image_path))
            return True
        except Exception:
            # e.g. formats IPython's Image cannot embed (svg, pdf, ...)
            return False

    def display_with_matplotlib(
        self,
        layout=None,
        node_size=300,
        node_color="skyblue",
        edge_color="black",
        with_labels=True,
        font_size=10,
        figsize=(8, 6),
        output_file=None,
    ):
        """
        @brief Draw the graph with Matplotlib.

        @param layout: callable that takes the graph and returns a node
                       position dict (e.g. nx.spring_layout); None picks a
                       default based on whether the graph is directed
        @param node_size: size of nodes
        @param node_color: color of nodes
        @param edge_color: color of edges
        @param with_labels: whether to display node labels
        @param font_size: size of node labels
        @param figsize: figure size
        @param output_file: if given, save the figure to this path instead of
                            showing it
        """
        fig = None
        try:
            fig = plt.figure(figsize=figsize)
            ax = fig.gca()

            pos = self._default_positions() if layout is None else layout(self.graph)

            nx.draw_networkx(
                self.graph,
                pos=pos,
                ax=ax,
                with_labels=with_labels,
                node_size=node_size,
                node_color=node_color,
                edge_color=edge_color,
                font_size=font_size,
            )
            ax.set_axis_off()
            fig.tight_layout()

            if output_file:
                fig.savefig(output_file)
                plt.close(fig)
                print(f"Matplotlib visualization saved to {output_file}")
                return

            try:
                plt.show()
            except Exception:
                # Showing failed (e.g. no display): fall back to a file.
                temp_file = f"graph_matplotlib_{int(time.time())}.png"
                fig.savefig(temp_file)
                plt.close(fig)
                print(f"Matplotlib visualization saved to {temp_file}")
                print(
                    "(Note: Interactive display not available in terminal environment)"
                )

        except Exception as e:
            print(f"Error displaying graph with matplotlib: {e}")
            import traceback

            traceback.print_exc()
            if fig is not None:
                # Don't leave a half-drawn figure registered with pyplot.
                plt.close(fig)

    def _default_positions(self):
        """
        @brief Choose node positions when no layout function was given.

        @return position dict: spring layout for directed graphs; Kamada-Kawai
                for undirected ones, falling back to spring layout if it fails
        """
        if self.graph.is_directed():
            return nx.spring_layout(self.graph)
        try:
            return nx.kamada_kawai_layout(self.graph)
        except Exception:
            # kamada_kawai can fail with certain graph structures
            return nx.spring_layout(self.graph)

    def get_node_count(self):
        """
        @brief return the number of nodes in the graph.
        @return num of nodes in graph
        """
        return self.graph.number_of_nodes()

    def get_edge_count(self):
        """
        @brief return the number of edges in the graph.
        @return num of edges in graph
        """
        return self.graph.number_of_edges()

    def get_nodes(self):
        """
        @brief return the list of nodes in the graph
        @return list of nodes in graph (this order is also the row/column order
                of get_adjacency_matrix())
        """
        return list(self.graph.nodes())

    def get_edges(self):
        """
        @brief return the list of edges in the graph
        @return list of (start, dest) tuples
        """
        return list(self.graph.edges())

    def get_successors(self, node):
        """
        @brief return successors of a node

        For undirected graphs (e.g. loaded from an undirected DOT file) every
        neighbor counts as a successor.

        @param node: the node whose successors are requested
        @return list of successors of node
        @throws nx.NetworkXError if node is not in the graph
        """
        if self.graph.is_directed():
            return list(self.graph.successors(node))
        # nx.Graph has no successors(); neighbors() is the undirected analogue.
        return list(self.graph.neighbors(node))

    def get_predecessors(self, node):
        """
        @brief return the predecessors of a node

        For undirected graphs (e.g. loaded from an undirected DOT file) every
        neighbor counts as a predecessor.

        @param node: the node whose predecessors are requested
        @return list of predecessors of node
        @throws nx.NetworkXError if node is not in the graph
        """
        if self.graph.is_directed():
            return list(self.graph.predecessors(node))
        # nx.Graph has no predecessors(); neighbors() is the undirected analogue.
        return list(self.graph.neighbors(node))

    def get_adjacency_matrix(self, weight=None):
        """
        @brief return the adjacency matrix of the graph

        @param weight: Edge weight. If None, all edges have weight 1. If a
                       string, the edge attribute with that name (edges
                       lacking it get 1). If a function, it is called with the
                       edge's attribute dict and its return value is used.
        @return numpy array: the adjacency matrix of the graph. Rows and columns
                are ordered according to self.get_nodes(). Non-existent edges
                are represented by zeroes. None if every attempt failed.
        """
        if callable(weight):
            # nx.to_numpy_array only understands attribute names; given a
            # callable it silently falls back to weight 1 for every edge.
            return self._manual_adjacency_matrix(weight)

        try:
            return nx.to_numpy_array(
                self.graph, nodelist=self.get_nodes(), weight=weight
            )
        except Exception as e:
            print(f"Error generating adjacency matrix: {e}")
            import traceback

            traceback.print_exc()

        # Fallback implementation if NetworkX's function fails
        return self._manual_adjacency_matrix(weight)

    def _manual_adjacency_matrix(self, weight):
        """
        @brief Build the adjacency matrix by iterating over the edges directly.

        @param weight: None, an edge-attribute name, or a function of the edge
                       attribute dict (see get_adjacency_matrix)
        @return numpy array ordered like self.get_nodes(), or None on failure
                (the error is printed)
        """
        try:
            nodes = self.get_nodes()
            node_to_idx = {node: i for i, node in enumerate(nodes)}
            adj_matrix = np.zeros((len(nodes), len(nodes)))
            directed = self.graph.is_directed()

            for u, v, data in self.graph.edges(data=True):
                if weight is None:
                    value = 1
                elif callable(weight):
                    value = weight(data)
                elif weight in data:
                    value = float(data[weight])
                else:
                    # Edges without the requested attribute count as 1
                    value = 1

                i, j = node_to_idx[u], node_to_idx[v]
                adj_matrix[i, j] = value
                if not directed:
                    # Undirected edges are reported once; mirror them.
                    adj_matrix[j, i] = value

            return adj_matrix

        except Exception as e:
            print(f"Fallback adjacency matrix calculation failed: {e}")
            import traceback

            traceback.print_exc()
            return None
Custom Graph class that can load, manipulate, and visualize graphs from .dot files.
display_with_matplotlib(self, layout=None, node_size=300, node_color="skyblue", edge_color="black", with_labels=True, font_size=10, figsize=(8, 6), output_file=None)
Visualize the graph with Matplotlib.
get_adjacency_matrix(self, weight=None)
return the adjacency matrix of the graph
add_node(self, node_id, **attributes)
Add a node to the current graph.
get_successors(self, node)
return successors of a node
get_nodes(self)
return the list of nodes in the graph
get_predecessors(self, node)
return the predecessors of a node
__init__(self)
Initialize an empty graph.
get_edge_count(self)
return the number of edges in the graph.
save_to_dot(self, file_path)
Save the current graph to a .dot file.
add_edge(self, start, dest, **attributes)
Add an edge to the graph.
get_edges(self)
return the list of edges in the graph
load_from_dot(self, file_path)
Loads a graph from a .dot file.
get_node_count(self)
return the number of nodes in the graph.
display_with_graphviz(self, layout="dot", format="png", show=True, output_file=None)
Visualize with GraphViz by creating a png.