Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More Plotting Utils: Realtime Plots, Average Only and CI plots #140

Merged
merged 2 commits into from
Nov 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 235 additions & 25 deletions src/utils/post_hoc_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,37 +183,231 @@ def aggregate_per_round_data(logs_dir: str, metrics_map: Optional[Dict[str, str]
return all_users_data

# Plotting
def plot_metric_per_round(metric_df: pd.DataFrame, rounds: np.ndarray, metric_name: str, ylabel: str, output_dir: str, plot_avg_only: bool = False) -> None:
    """Plot per-round data for each user and the aggregate (mean with 95% CI).

    Args:
        metric_df (pd.DataFrame): Metric data, one column per user, one row per round.
        rounds (np.ndarray): Round indices used as the x-axis.
        metric_name (str): Short metric key used in the output file names (e.g. 'test_acc').
        ylabel (str): Human-readable y-axis label.
        output_dir (str): Directory (with trailing separator) where CSVs and the PNG are written.
        plot_avg_only (bool): If True, skip the individual per-user curves and plot
            only the average with its confidence band.
    """
    plt.figure(figsize=(12, 8), dpi=300)

    # Per-user curves (suppressed when only the aggregate is wanted).
    if not plot_avg_only:
        for col in metric_df.columns:
            plt.plot(rounds, metric_df[col], alpha=0.6)

    # Aggregate statistics across users: mean, std, and 95% confidence interval.
    mean_metric = metric_df.mean(axis=1)
    std_metric = metric_df.std(axis=1)
    n = metric_df.shape[1]
    ci_95 = 1.96 * (std_metric / np.sqrt(n))

    # Persist the aggregate series alongside the plot.
    # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    mean_metric.to_csv(f'{output_dir}{metric_name}_avg.csv', index=False)
    std_metric.to_csv(f'{output_dir}{metric_name}_std.csv', index=False)
    ci_95.to_csv(f'{output_dir}{metric_name}_ci95.csv', index=False)

    # Average curve with the 95% CI as a shaded band.
    plt.plot(rounds, mean_metric, label='Average', color='black', linestyle='--')
    plt.fill_between(rounds, mean_metric - ci_95, mean_metric + ci_95, color='gray', alpha=0.2, label='95% CI')

    # Labels, title, and a light grid for readability.
    plt.xlabel('Rounds (Iterations)', fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.title(f'{ylabel} per User and Aggregate', fontsize=16)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(fontsize=12)

    # Save the plot.
    plt.savefig(f'{output_dir}{metric_name}_per_round.png', bbox_inches='tight')
    plt.close()

def compute_per_user_realtime_data(
    node_id: str,
    logs_dir: str,
    time_interval: int,
    num_ticks: Optional[int] = None
) -> Dict[str, pd.Series]:
    """
    Resample a user's per-round metrics onto a fixed wall-clock grid.

    Each tick carries the metrics of the most recent round whose elapsed time is
    at or before that tick; ticks before the first completed round hold NaN.

    Args:
        node_id (str): ID of the node (user).
        logs_dir (str): Directory path where logs are stored.
        time_interval (int): Seconds between consecutive ticks; must be positive.
        num_ticks (Optional[int]): Total number of ticks to emit. Defaults to
            enough ticks to cover the last logged elapsed time.

    Returns:
        Dict[str, pd.Series]: One time-indexed Series per metric.

    Raises:
        ValueError: If time_interval is not positive (previously this crashed
            with an opaque ZeroDivisionError).
    """
    if time_interval <= 0:
        raise ValueError(f'time_interval must be positive, got {time_interval}')

    # Elapsed wall-clock time at the end of each round.
    time_data = load_logs(node_id, 'time_elapsed', logs_dir)
    round_times = time_data['time_elapsed'].values

    # Per-round metric values for this user; 'iteration' is bookkeeping, not a metric.
    per_round_data = compute_per_user_round_data(node_id, logs_dir)
    metric_keys = [key for key in per_round_data.keys() if key != 'iteration']
    per_time_data: Dict[str, list] = {key: [] for key in metric_keys}

    # Number of ticks: explicit override, or enough to cover the final round time.
    max_time = round_times[-1] if len(round_times) > 0 else 0
    calculated_ticks = int(max_time // time_interval + 1)
    total_ticks = num_ticks if num_ticks is not None else calculated_ticks

    time_ticks = [tick * time_interval for tick in range(1, total_ticks + 1)]

    # Sweep a cursor over the (assumed nondecreasing) round times so the whole
    # resampling is a single O(rounds + ticks) pass.
    round_idx = 0
    for current_time in time_ticks:
        # Advance past every round that finished at or before this tick.
        while round_idx < len(round_times) and round_times[round_idx] <= current_time:
            round_idx += 1
        if round_idx > 0:
            # Carry forward the metrics of the last completed round.
            for key in metric_keys:
                per_time_data[key].append(per_round_data[key][round_idx - 1])
        else:
            # No round has completed yet at this tick.
            for key in metric_keys:
                per_time_data[key].append(np.nan)

    # NOTE: the loop above emits exactly one value per tick, so no tail-padding
    # is needed (the old "fill remaining ticks" step was dead code).

    # Convert lists to Series indexed by the tick times.
    return {
        key: pd.Series(data=values, index=time_ticks) for key, values in per_time_data.items()
    }


def aggregate_per_realtime_data(
    logs_dir: str,
    metrics_map: Optional[Dict[str, str]] = None,
    time_interval: Optional[int] = None,
    num_ticks: Optional[int] = 200,
) -> Dict[str, pd.DataFrame]:
    """
    Aggregate the per-time data for all users.

    Args:
        logs_dir (str): Directory path where logs are stored.
        metrics_map (Optional[Dict[str, str]]): Mapping of metric names to file names.
        time_interval (Optional[int]): Interval in seconds for each tick in the
            real-time plot. If omitted, it is derived from the slowest user's
            total elapsed time divided by num_ticks.
        num_ticks (Optional[int]): Number of ticks to display. Used if
            time_interval is not provided.

    Returns:
        Dict[str, pd.DataFrame]: Per-metric DataFrames indexed by time, one
        column per user.
    """
    if metrics_map is None:
        metrics_map = {
            'train_acc': 'train_acc',
            'test_acc': 'test_acc',
            'train_loss': 'train_loss',
            'test_loss': 'test_loss',
        }

    nodes = get_all_nodes(logs_dir)

    # Step 1: derive time_interval from the maximum elapsed time across nodes
    # when the caller did not fix it explicitly.
    if time_interval is None:
        max_elapsed_time = 0
        for node in nodes:
            node_id = node.split('_')[-1]
            time_data = load_logs(node_id, 'time_elapsed', logs_dir)
            max_elapsed_time = max(max_elapsed_time, time_data['time_elapsed'].max())

        # Floor division yields 0 when the run is shorter than num_ticks
        # seconds, which would crash the resampler; clamp to at least 1 second.
        time_interval = max(1, max_elapsed_time // num_ticks)

    # Initialize aggregated data storage.
    all_users_data = {metric: [] for metric in metrics_map}
    time_ticks = None

    # Step 2: resample each user onto the shared time grid and collect results.
    for node in nodes:
        node_id = node.split('_')[-1]
        user_data = compute_per_user_realtime_data(node_id, logs_dir, time_interval, num_ticks=num_ticks)

        for key in metrics_map:
            all_users_data[key].append(user_data[key].values)  # each of shape (num_ticks,)

        # Ticks are identical for every user; record them once.
        if time_ticks is None:
            time_ticks = user_data[list(metrics_map.keys())[0]].index.values

    # Stack the per-user arrays column-wise into one DataFrame per metric
    # (rows: time ticks, columns: users).
    aggregated_data = {
        key: pd.DataFrame(np.stack(all_users_data[key], axis=1), index=time_ticks)
        for key in metrics_map
    }

    return aggregated_data


def plot_metric_per_realtime(metric_df: pd.DataFrame, time_ticks: np.ndarray, metric_name: str, ylabel: str, output_dir: str, plot_avg_only: bool = False) -> None:
    """
    Plot per-time-elapsed data for each user and the aggregate (mean with 95% CI).

    Args:
        metric_df (pd.DataFrame): Metric data for each user (one column per user).
        time_ticks (np.ndarray): Time-elapsed value for each tick (x-axis).
        metric_name (str): Name of the metric (e.g., 'train_acc', 'test_loss').
        ylabel (str): Label for the y-axis of the plot.
        output_dir (str): Directory to save the plot and CSV files.
        plot_avg_only (bool): If True, plot only the average and CI band,
            omitting the individual user curves.
    """
    plt.figure(figsize=(12, 8), dpi=300)

    # Per-user curves (suppressed when only the aggregate is wanted).
    if not plot_avg_only:
        for col in metric_df.columns:
            plt.plot(time_ticks, metric_df[col], alpha=0.6)

    # Aggregate statistics across users: mean, std, and 95% confidence interval.
    mean_metric = metric_df.mean(axis=1)
    std_metric = metric_df.std(axis=1)
    n = metric_df.shape[1]
    ci_95 = 1.96 * (std_metric / np.sqrt(n))

    # Ensure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)

    # The '_per_time' suffix keeps these from clobbering the per-round CSVs
    # that plot_metric_per_round writes into the same directory (the CI file
    # previously collided as '<metric>_ci95.csv').
    mean_metric.to_csv(f'{output_dir}{metric_name}_avg_per_time.csv', index=False)
    std_metric.to_csv(f'{output_dir}{metric_name}_std_per_time.csv', index=False)
    ci_95.to_csv(f'{output_dir}{metric_name}_ci95_per_time.csv', index=False)

    # Average curve with the 95% CI as a shaded band.
    plt.plot(time_ticks, mean_metric, label='Average', color='black', linestyle='--')
    plt.fill_between(time_ticks, mean_metric - ci_95, mean_metric + ci_95, color='gray', alpha=0.2, label='95% CI')

    # Labels, title, and a light grid for readability.
    plt.xlabel('Time Elapsed (seconds)', fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.title(f'{ylabel} per User and Aggregate over Time Elapsed', fontsize=16)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(fontsize=12)

    # bbox_inches='tight' for parity with the per-round plot output.
    plt.savefig(f'{output_dir}{metric_name}_per_time.png', bbox_inches='tight')
    plt.close()

def plot_all_metrics(logs_dir: str, metrics_map: Optional[Dict[str, str]] = None) -> None:
def plot_all_metrics(logs_dir: str, per_round: bool = True, per_time: bool = True, metrics_map: Optional[Dict[str, str]] = None, plot_avg_only: bool=False, **kwargs) -> None:
"""Generates plots for all metrics over rounds with aggregation."""
if metrics_map is None:
metrics_map = {
Expand All @@ -223,25 +417,41 @@ def plot_all_metrics(logs_dir: str, metrics_map: Optional[Dict[str, str]] = None
'train_loss': 'Train Loss'
}

all_users_data = aggregate_per_round_data(logs_dir)

for key, display_name in metrics_map.items():
plot_metric_per_round(
metric_df=all_users_data[key],
rounds=all_users_data['rounds'],
metric_name=key,
ylabel=display_name,
output_dir=f'{logs_dir}plots/'
)
if per_round:
all_users_data = aggregate_per_round_data(logs_dir, **kwargs)

for key, display_name in metrics_map.items():
plot_metric_per_round(
metric_df=all_users_data[key],
rounds=all_users_data['rounds'],
metric_name=key,
ylabel=display_name,
output_dir=f'{logs_dir}plots/',
plot_avg_only=plot_avg_only,
**kwargs
)
if per_time:
all_users_data = aggregate_per_realtime_data(logs_dir, **kwargs)

for key, display_name in metrics_map.items():
plot_metric_per_realtime(
metric_df=all_users_data[key],
time_ticks=all_users_data[key].index.values,
metric_name=key,
ylabel=display_name,
output_dir=f'{logs_dir}plots/',
plot_avg_only=plot_avg_only,
**kwargs
)

print("Plots saved as PNG files.")

# Use if you have a specific experiment folder
# if __name__ == "__main__":
# # Define the path where your experiment logs are saved
# logs_dir = '/mas/camera/Experiments/SONAR/abhi/cifar10_36users_1250_convergence_ringm3_seed2/logs/'
# logs_dir = '/mas/camera/Experiments/SONAR/jyuan/experiment/logs_sample_time_elapsed/'
# avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir)
# plot_all_metrics(logs_dir)
# plot_all_metrics(logs_dir, per_round=True, per_time=True, plot_avg_only=True)


# Use if you want to compute for multiple experiment folders
Expand Down
Loading