From 839e1c9224f044e49e8976a76b1c4c93283ed867 Mon Sep 17 00:00:00 2001 From: Clay Date: Thu, 1 Feb 2024 12:19:01 -0500 Subject: [PATCH 1/5] generalize cpu count to windows and mac --- interpretability/task_modeling/task_train_prep.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/interpretability/task_modeling/task_train_prep.py b/interpretability/task_modeling/task_train_prep.py index 62ea092..dc47341 100644 --- a/interpretability/task_modeling/task_train_prep.py +++ b/interpretability/task_modeling/task_train_prep.py @@ -7,6 +7,7 @@ import hydra import pytorch_lightning as pl from gymnasium import Env +import platform from interpretability.task_modeling.simulator.neural_simulator import ( NeuralDataSimulator, @@ -138,7 +139,10 @@ def train( accelerator="auto", _convert_="all", ) - print(len(os.sched_getaffinity(0))) + if platform.system() == 'Windows' or platform.system() == 'Darwin': # Darwin indicates MacOS + print(os.cpu_count()) + else: + print(len(os.sched_getaffinity(0))) # -----------------------------Train model--------------------------- log.info("Training model") trainer.fit(model=task_wrapper, datamodule=datamodule) From 6a8f84dd6e50565e8d2d0bb7445acd74db0a00e1 Mon Sep 17 00:00:00 2001 From: Clay Date: Thu, 1 Feb 2024 12:33:06 -0500 Subject: [PATCH 2/5] necessary for me to plot --- interpretability/task_modeling/callbacks/callbacks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/interpretability/task_modeling/callbacks/callbacks.py b/interpretability/task_modeling/callbacks/callbacks.py index fe18934..6c61767 100644 --- a/interpretability/task_modeling/callbacks/callbacks.py +++ b/interpretability/task_modeling/callbacks/callbacks.py @@ -1,6 +1,8 @@ import io import matplotlib.pyplot as plt +import matplotlib +matplotlib.use('Agg') import numpy as np import pytorch_lightning as pl import torch From cf4074d572a2a38c8f7c0b1b3149c4854e9e6539 Mon Sep 17 00:00:00 2001 From: Clay Date: Thu, 1 Feb 2024 12:36:52 -0500 Subject: [PATCH 3/5] creates full RUN_DIR --- examples/run_task_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/run_task_training.py b/examples/run_task_training.py index ef031bb..78bac4d 100644 --- a/examples/run_task_training.py +++ b/examples/run_task_training.py @@ -23,7 +23,7 @@ log = logging.getLogger(__name__) # ---------------Options--------------- -LOCAL_MODE = False # Set to True to run locally (for debugging) +LOCAL_MODE = True # Set to True to run locally (for debugging) OVERWRITE = True # Set to True to overwrite existing run RUN_DESC = "NBFF_Tutorial" # For WandB and run dir TASK = "NBFF" # Task to train on (see configs/task_env for options) @@ -81,7 +81,7 @@ def main( if RUN_DIR.exists() and OVERWRITE: shutil.rmtree(RUN_DIR) - RUN_DIR.mkdir() + RUN_DIR.mkdir(parents=True, exist_ok=True) shutil.copyfile(__file__, RUN_DIR / Path(__file__).name) tune.run( tune.with_parameters( From 73b15d16dcc3bea42c65fc0c59f4b7add0cd2966 Mon Sep 17 00:00:00 2001 From: Clay Date: Fri, 2 Feb 2024 13:01:20 -0500 Subject: [PATCH 4/5] last small changes for pr --- examples/run_task_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_task_training.py b/examples/run_task_training.py index 78bac4d..34d9c1d 100644 --- a/examples/run_task_training.py +++ b/examples/run_task_training.py @@ -23,7 +23,7 @@ log = logging.getLogger(__name__) # ---------------Options--------------- -LOCAL_MODE = True # Set to True to run locally (for debugging) +LOCAL_MODE = False # Set to True to run locally (for debugging) OVERWRITE = True # Set to True to overwrite existing run RUN_DESC = "NBFF_Tutorial" # For WandB and run dir TASK = "NBFF" # Task to train on (see configs/task_env for options) From 09a4a2c752c7952e44904ff05fd339ac32f1c985 Mon Sep 17 00:00:00 2001 From: Clay Date: Fri, 2 Feb 2024 13:02:32 -0500 Subject: [PATCH 5/5] gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 6f3fa38..0069905 100644 --- a/.gitignore +++ b/.gitignore @@ -118,7 +118,7 @@ celerybeat.pid *.sage.py # Environments -#.env +.env .venv env/ venv/