Skip to content

Commit

Permalink
Add dataset plots
Browse files Browse the repository at this point in the history
  • Loading branch information
robin-janssen committed Oct 8, 2024
1 parent 9b098e6 commit 14653ad
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 26 deletions.
16 changes: 8 additions & 8 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Global settings for the benchmark
training_id: "delete_me5"
surrogates: ["LatentPoly", "FullyConnected", "MultiONet"]
batch_size: [256, 256, 256]
epochs: [2, 2, 2]
surrogates: ["MultiONet"]
batch_size: [256]
epochs: [10]
dataset:
name: "osu2008"
log10_transform: True
normalise: "minmax" # "standardise", "minmax", "disable"
name: "branca24"
log10_transform: False
normalise: "disable" # "minmax" # "standardise", "minmax", "disable"
use_optimal_params: True
tolerance: 1e-20
devices: ["cuda:1"]
tolerance: 1e-30
devices: ["cpu"]
seed: 42
verbose: False

Expand Down
Binary file added datasets/branca24/example_trajectories.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 16 additions & 10 deletions datasets/data_analysis/analyse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@

from codes import check_and_load_data

from .data_plots import plot_example_trajectories, plot_example_trajectories_paper
from datasets.data_analysis.data_plots import (
plot_example_trajectories,
# plot_example_trajectories_paper,
)


def main(args):
"""
Main function to analyse the dataset. It checks the dataset and loads the data.
"""
log = True
# Load full data
(
full_train_data,
Expand All @@ -24,7 +28,7 @@ def main(args):
) = check_and_load_data(
args.dataset,
verbose=False,
log=True,
log=log,
normalisation_mode="disable",
)

Expand All @@ -36,21 +40,23 @@ def main(args):
num_chemicals=20,
save=True,
labels=labels,
sample_idx=7,
log=log,
)

plot_example_trajectories_paper(
args.dataset,
full_train_data,
timesteps,
save=True,
labels=labels,
)
# plot_example_trajectories_paper(
# args.dataset,
# full_train_data,
# timesteps,
# save=True,
# labels=labels,
# )


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--dataset", default="osu2008", type=str, help="Name of the dataset."
"--dataset", default="branca24", type=str, help="Name of the dataset."
)
args = parser.parse_args()
main(args)
21 changes: 13 additions & 8 deletions datasets/data_analysis/data_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def plot_example_trajectories(
labels: list[str] | None = None,
save: bool = False,
sample_idx: int = 0,
log: bool = False,
) -> None:
"""
Plot example trajectories for the dataset.
Expand Down Expand Up @@ -48,7 +49,7 @@ def plot_example_trajectories(
gt,
"-",
color=color,
label=f"Chemical {chem_idx+1}" if labels is None else labels[chem_idx],
label=f"Quantity {chem_idx + 1}" if labels is None else labels[chem_idx],
)

# Set labels and title
Expand All @@ -57,9 +58,10 @@ def plot_example_trajectories(
# Remove all ticks
plt.tick_params(axis="x", which="both", bottom=False, top=False)
plt.tick_params(axis="y", which="both", left=False, right=False)
plt.ylabel("Chemical Abundance")
ylabel = "log(Abundance)" if log else "Abundance"
plt.ylabel(ylabel)
plt.title(f"Example Trajectories for Dataset: {dataset_name}")
plt.legend(title="Chemicals")
plt.legend(title="Quantity")
plt.grid(True)

# Save the plot if required
Expand All @@ -80,7 +82,7 @@ def plot_example_trajectories(
increase_count=False,
)

plt.show()
# plt.show()


def plot_example_trajectories_paper(
Expand All @@ -106,8 +108,11 @@ def plot_example_trajectories_paper(
data = data[sample_idx]

# Define the number of chemicals per subplot
num_chemicals_subplots = [15, 14]
total_chemicals = num_chemicals_subplots[0] + num_chemicals_subplots[1]
total_chemicals = data.shape[1]
num_chemicals_subplots = [
total_chemicals // 2,
total_chemicals - total_chemicals // 2,
]

# Ensure the labels list matches the number of chemicals
if labels is not None:
Expand All @@ -130,14 +135,14 @@ def plot_example_trajectories_paper(
for chem_idx in range(num_chemicals_subplots[0]):
color = colors1[chem_idx]
gt = data[:, chem_idx]
label = labels[chem_idx] if labels is not None else f"Chemical {chem_idx+1}"
label = labels[chem_idx] if labels is not None else f"Chemical {chem_idx + 1}"
ax1.plot(timesteps, gt, "-", color=color, label=label)

# Plot second set of chemicals on ax2
for chem_idx in range(num_chemicals_subplots[0], total_chemicals):
color = colors2[chem_idx - num_chemicals_subplots[0]]
gt = data[:, chem_idx]
label = labels[chem_idx] if labels is not None else f"Chemical {chem_idx+1}"
label = labels[chem_idx] if labels is not None else f"Chemical {chem_idx + 1}"
ax2.plot(timesteps, gt, "-", color=color, label=label)

# Set labels and title
Expand Down
Binary file added datasets/lotka_volterra/example_trajectories.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added datasets/simple_ode/example_trajectories.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 14653ad

Please sign in to comment.