From b76f33cdd8c606da9df0ed3c4d624fa4e4543a9e Mon Sep 17 00:00:00 2001 From: Ayman Noreldaim <84142410+aymann121@users.noreply.github.com> Date: Sat, 7 Dec 2024 02:24:40 -0500 Subject: [PATCH] Update post_hoc_plot_utils.py (#147) use the post_hoc_plot_utils file to create network visualizations on data that has already been trained --- src/utils/post_hoc_plot_utils.py | 170 +++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) diff --git a/src/utils/post_hoc_plot_utils.py b/src/utils/post_hoc_plot_utils.py index 176414d..b53cbc5 100644 --- a/src/utils/post_hoc_plot_utils.py +++ b/src/utils/post_hoc_plot_utils.py @@ -5,6 +5,9 @@ from typing import List, Dict, Tuple, Optional import matplotlib.pyplot as plt import json +import networkx as nx +import imageio +from glob import glob # Load Logs def load_logs(node_id: str, metric_type: str, logs_dir: str) -> pd.DataFrame: @@ -407,6 +410,148 @@ def plot_metric_per_realtime(metric_df: pd.DataFrame, time_ticks: np.ndarray, me plt.savefig(f'{output_dir}{metric_name}_per_time.png') plt.close() + + +def create_weighted_images(neighbors, output_dir: str, pos): + """ + Create the images for the network visualization. + + Parameters: + - neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors + """ + #create a network x graph and visualize it for each round + + freq = np.zeros((neighbors.shape[1], neighbors.shape[1])) + for round in range(neighbors.shape[0]): + + for node in range(neighbors.shape[1]): + for neighbor in neighbors[round][node]: + freq[node][neighbor-1] += 1 + + # Create the directed graph + graph = nx.DiGraph() + #add edges based on which edges in freq are greater than 0 and use that as the weight + for i in range(neighbors.shape[1]): + for j in range(neighbors.shape[1]): + if freq[i][j] > 0: + graph.add_edge(i + 1, j + 1, weight=3 * freq[i][j]) + + + + + # draw nodes + nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue") + nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold") + + #make opposite edges not overlap by adding curvature and make edges thicker based on frequency + curvatureDict = {} + for _, (u, v) in enumerate(graph.edges()): + # make sure u v and v u always have different curvature + if (u,v) not in curvatureDict: + curvatureDict[(u,v)] = 0.1 + curvatureDict[(v,u)] = 0.1 + + rad = curvatureDict[(u,v)] + nx.draw_networkx_edges( + graph, + pos, + edgelist=[(u, v)], + connectionstyle=f"arc3,rad={rad}", + width=freq[u-1][v-1]/3, + arrows=True, + arrowsize=20 + ) + + # Create the image + plt.title(f"Round {round + 1}") + plt.savefig(f"{output_dir}/weighted_graph_{round + 1}.png") + + plt.close() + +def create_images(neighbors, output_dir: str, pos): + """ + Create the images for the network visualization. + + Parameters: + - neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors + """ + #create a network x graph and visualize it for each round + for round in range(neighbors.shape[0]): + + # Create the directed graph + graph = nx.DiGraph() + for node in range(neighbors.shape[1]): + for neighbor in neighbors[round][node]: + graph.add_edge(node + 1, neighbor) + + # draw nodes + nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue") + nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold") + + + #make opposite edges not overlap by adding curvature + curvatureDict = {} + for i, (u, v) in enumerate(graph.edges()): + # make sure u v and v u always have different curvature + if (u,v) not in curvatureDict: + curvatureDict[(u,v)] = 0.1 + curvatureDict[(v,u)] = 0.1 + + rad = curvatureDict[(u,v)] + nx.draw_networkx_edges( + graph, + pos, + edgelist=[(u, v)], + connectionstyle=f"arc3,rad={rad}", + arrows=True, + arrowsize=20 + ) + + # Create the image + plt.title(f"Round {round + 1}") + plt.savefig(f"{output_dir}/graph_{round + 1}.png") + plt.close() + +def create_video(output_dir: str, image_name: str): + """Create a gif from the images.""" + images = [] + for filename in sorted(glob(f"{output_dir}/{image_name}_*.png")): + images.append(imageio.imread(filename)) + imageio.mimsave(f"{output_dir}/{image_name}_video.gif", images, fps = 1, loop = 0) + + + +def create_heatmap(neighbors, output_dir: str): + """ + Create a heatmap of the edge frequency. + + Parameters: + - neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors + """ + + # Initialize the edge frequency matrix + edge_frequency_matrix = np.zeros((neighbors.shape[1]+1, neighbors.shape[1]+1)) + # Iterate over all the rounds + for round in range(neighbors.shape[0]): + # Iterate over all the nodes + for node in range(neighbors.shape[1]): + # Iterate over all the + for neighbor in neighbors[round][node]: + edge_frequency_matrix[node+1][neighbor] += 1 + + edge_frequency_matrix = np.log(edge_frequency_matrix + 1) # Log scale for better visualization + # Create the heatmap + plt.figure(figsize=(10, 6)) + plt.imshow(edge_frequency_matrix, cmap="hot", interpolation="nearest") + plt.title("Edge Frequency Matrix") + plt.colorbar(label="Frequency of Communication") + plt.xlabel("Node") + plt.ylabel("Node") + plt.xticks(range(1,neighbors.shape[1]+1)) + plt.yticks(range(1,neighbors.shape[1]+1)) + plt.savefig(f"{output_dir}/edge_frequency_heatmap.png") + plt.close() + def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = True, metrics_map: Optional[Dict[str, str]] = None, plot_avg_only: bool=False, **kwargs) -> None: """Generates plots for all metrics over rounds with aggregation.""" if metrics_map is None: @@ -449,9 +594,34 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru plot_avg_only=plot_avg_only, **kwargs ) + + neighbors = aggregate_neighbors_across_users(logs_dir) + # create_heatmap(neighbors, f'{os.path.dirname(logs_dir)}/plots/') + pos = nx.spring_layout(nx.DiGraph({i+1: [] for i in range(neighbors.shape[1])})) + create_images(neighbors, f'{os.path.dirname(logs_dir)}/plots/', pos) + create_weighted_images(neighbors, f'{os.path.dirname(logs_dir)}/plots/', pos) + create_video(f'{os.path.dirname(logs_dir)}/plots/', 'graph') + create_video(f'{os.path.dirname(logs_dir)}/plots/', 'weighted_graph') + create_heatmap(neighbors, f'{os.path.dirname(logs_dir)}/plots/') + print("Plots saved as PNG files.") +def aggregate_neighbors_across_users(logs_dir: str) -> np.ndarray: + """Aggregate the neighbors of each node across all users.""" + nodes = get_all_nodes(logs_dir) + nodes.sort() # Sort the nodes to ensure consistent order + + all_users_neighbors = [] + + for node in nodes: + node_id = node.split('_')[-1] + neighbors_file = os.path.join(logs_dir, f'node_{node_id}/csv/neighbors.csv') + neighbors = pd.read_csv(neighbors_file) + np.array(all_users_neighbors.append(neighbors['neighbors'].apply(json.loads).values)) + + return np.array(all_users_neighbors).T + # Use if you a specific experiment folder # if __name__ == "__main__": # # Define the path where your experiment logs are saved