Skip to content

Commit

Permalink
Update post_hoc_plot_utils.py (#147)
Browse files Browse the repository at this point in the history
use the post_hoc_plot_utils file to create network visualizations on data that has already been trained
  • Loading branch information
aymann121 authored Dec 7, 2024
1 parent c7ca0c0 commit b76f33c
Showing 1 changed file with 170 additions and 0 deletions.
170 changes: 170 additions & 0 deletions src/utils/post_hoc_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from typing import List, Dict, Tuple, Optional
import matplotlib.pyplot as plt
import json
import networkx as nx
import imageio
from glob import glob

# Load Logs
def load_logs(node_id: str, metric_type: str, logs_dir: str) -> pd.DataFrame:
Expand Down Expand Up @@ -407,6 +410,148 @@ def plot_metric_per_realtime(metric_df: pd.DataFrame, time_ticks: np.ndarray, me
plt.savefig(f'{output_dir}{metric_name}_per_time.png')
plt.close()



def create_weighted_images(neighbors, output_dir: str, pos):
"""
Create the images for the network visualization.
Parameters:
- neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors
"""
#create a network x graph and visualize it for each round

freq = np.zeros((neighbors.shape[1], neighbors.shape[1]))
for round in range(neighbors.shape[0]):

for node in range(neighbors.shape[1]):
for neighbor in neighbors[round][node]:
freq[node][neighbor-1] += 1

# Create the directed graph
graph = nx.DiGraph()
#add edges based on which edges in freq are greater than 0 and use that as the weight
for i in range(neighbors.shape[1]):
for j in range(neighbors.shape[1]):
if freq[i][j] > 0:
graph.add_edge(i + 1, j + 1, weight=3 * freq[i][j])




# draw nodes
nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue")
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")

#make opposite edges not overlap by adding curvature and make edges thicker based on frequency
curvatureDict = {}
for _, (u, v) in enumerate(graph.edges()):
# make sure u v and v u always have different curvature
if (u,v) not in curvatureDict:
curvatureDict[(u,v)] = 0.1
curvatureDict[(v,u)] = 0.1

rad = curvatureDict[(u,v)]
nx.draw_networkx_edges(
graph,
pos,
edgelist=[(u, v)],
connectionstyle=f"arc3,rad={rad}",
width=freq[u-1][v-1]/3,
arrows=True,
arrowsize=20
)

# Create the image
plt.title(f"Round {round + 1}")
plt.savefig(f"{output_dir}/weighted_graph_{round + 1}.png")

plt.close()

def create_images(neighbors, output_dir: str, pos):
"""
Create the images for the network visualization.
Parameters:
- neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors
"""
#create a network x graph and visualize it for each round
for round in range(neighbors.shape[0]):

# Create the directed graph
graph = nx.DiGraph()
for node in range(neighbors.shape[1]):
for neighbor in neighbors[round][node]:
graph.add_edge(node + 1, neighbor)

# draw nodes
nx.draw_networkx_nodes(graph, pos, node_size=700, node_color="skyblue")
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")


#make opposite edges not overlap by adding curvature
curvatureDict = {}
for i, (u, v) in enumerate(graph.edges()):
# make sure u v and v u always have different curvature
if (u,v) not in curvatureDict:
curvatureDict[(u,v)] = 0.1
curvatureDict[(v,u)] = 0.1

rad = curvatureDict[(u,v)]
nx.draw_networkx_edges(
graph,
pos,
edgelist=[(u, v)],
connectionstyle=f"arc3,rad={rad}",
arrows=True,
arrowsize=20
)

# Create the image
plt.title(f"Round {round + 1}")
plt.savefig(f"{output_dir}/graph_{round + 1}.png")
plt.close()

def create_video(output_dir: str, image_name: str):
"""Create a gif from the images."""
images = []
for filename in sorted(glob(f"{output_dir}/{image_name}_*.png")):
images.append(imageio.imread(filename))
imageio.mimsave(f"{output_dir}/{image_name}_video.gif", images, fps = 1, loop = 0)



def create_heatmap(neighbors, output_dir: str):
"""
Create a heatmap of the edge frequency.
Parameters:
- neighbors: 3d numpy array of neighbors for each node x - round, y - node, z - neighbors
"""

# Initialize the edge frequency matrix
edge_frequency_matrix = np.zeros((neighbors.shape[1]+1, neighbors.shape[1]+1))
# Iterate over all the rounds
for round in range(neighbors.shape[0]):
# Iterate over all the nodes
for node in range(neighbors.shape[1]):
# Iterate over all the
for neighbor in neighbors[round][node]:
edge_frequency_matrix[node+1][neighbor] += 1

edge_frequency_matrix = np.log(edge_frequency_matrix + 1) # Log scale for better visualization
# Create the heatmap
plt.figure(figsize=(10, 6))
plt.imshow(edge_frequency_matrix, cmap="hot", interpolation="nearest")
plt.title("Edge Frequency Matrix")
plt.colorbar(label="Frequency of Communication")
plt.xlabel("Node")
plt.ylabel("Node")
plt.xticks(range(1,neighbors.shape[1]+1))
plt.yticks(range(1,neighbors.shape[1]+1))
plt.savefig(f"{output_dir}/edge_frequency_heatmap.png")
plt.close()

def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = True, metrics_map: Optional[Dict[str, str]] = None, plot_avg_only: bool=False, **kwargs) -> None:
"""Generates plots for all metrics over rounds with aggregation."""
if metrics_map is None:
Expand Down Expand Up @@ -449,9 +594,34 @@ def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = Tru
plot_avg_only=plot_avg_only,
**kwargs
)

neighbors = aggregate_neighbors_across_users(logs_dir)
# create_heatmap(neighbors, f'{os.path.dirname(logs_dir)}/plots/')
pos = nx.spring_layout(nx.DiGraph({i+1: [] for i in range(neighbors.shape[1])}))
create_images(neighbors, f'{os.path.dirname(logs_dir)}/plots/', pos)
create_weighted_images(neighbors, f'{os.path.dirname(logs_dir)}/plots/', pos)
create_video(f'{os.path.dirname(logs_dir)}/plots/', 'graph')
create_video(f'{os.path.dirname(logs_dir)}/plots/', 'weighted_graph')
create_heatmap(neighbors, f'{os.path.dirname(logs_dir)}/plots/')


print("Plots saved as PNG files.")

def aggregate_neighbors_across_users(logs_dir: str) -> np.ndarray:
"""Aggregate the neighbors of each node across all users."""
nodes = get_all_nodes(logs_dir)
nodes.sort() # Sort the nodes to ensure consistent order

all_users_neighbors = []

for node in nodes:
node_id = node.split('_')[-1]
neighbors_file = os.path.join(logs_dir, f'node_{node_id}/csv/neighbors.csv')
neighbors = pd.read_csv(neighbors_file)
np.array(all_users_neighbors.append(neighbors['neighbors'].apply(json.loads).values))

return np.array(all_users_neighbors).T

# Use if you a specific experiment folder
# if __name__ == "__main__":
# # Define the path where your experiment logs are saved
Expand Down

0 comments on commit b76f33c

Please sign in to comment.