From 5eddee17d755bd5cbde6d23e3f81403064c85297 Mon Sep 17 00:00:00 2001 From: Dorota Toczydlowska <115542912+dorotat-nv@users.noreply.github.com> Date: Tue, 7 Jan 2025 16:55:50 +0100 Subject: [PATCH] add initial configs for perf testing on ESM2 in JET (bionemo2) (#497) Adding ESM2 650M conv + perf partial configs for benchmarking with required changes --- ci/benchmarks/partial-conv/esm2_pretrain.yaml | 54 +++++++++++++++ ci/benchmarks/perf/esm2_pretrain.yaml | 65 +++++++++++++++++++ .../bionemo-esm2/src/bionemo/esm2/run/main.py | 21 +++++- .../src/bionemo/esm2/scripts/train_esm2.py | 40 +++++++++--- .../bionemo/esm2/scripts/test_train_esm2.py | 26 ++++++-- .../src/bionemo/geneformer/run/main.py | 21 +++++- .../geneformer/scripts/train_geneformer.py | 40 +++++++++--- .../scripts/test_train_geneformer.py | 60 ++++++++++++++++- .../src/bionemo/llm/run/config_models.py | 10 +++ .../bionemo-llm/src/bionemo/llm/train.py | 20 +++--- .../src/bionemo/llm/utils/logger_utils.py | 6 +- 11 files changed, 326 insertions(+), 37 deletions(-) create mode 100644 ci/benchmarks/partial-conv/esm2_pretrain.yaml create mode 100644 ci/benchmarks/perf/esm2_pretrain.yaml diff --git a/ci/benchmarks/partial-conv/esm2_pretrain.yaml b/ci/benchmarks/partial-conv/esm2_pretrain.yaml new file mode 100644 index 0000000000..fd232d922e --- /dev/null +++ b/ci/benchmarks/partial-conv/esm2_pretrain.yaml @@ -0,0 +1,54 @@ +scope: partial-conv +time_limit: 14400 +script_args: + # All arguments referenced in the script string must be specified here. + # Arguments not referenced in the script string must have the 'arg' field specified. + # See jet/core/configs.py for the specification of the configuration class + workspace: + value: /workspace/bionemo2 + key_segment: False + data_path: + value: /data/20240809_uniref_2024_03/data + key_segment: False + model: + value: esm2 + variant: + value: train + config_name: + value: 650M + precision: + value: [bf16-mixed] + nodes: + value: [4] + gpus: + value: 8 + batch_size: + value: 16 + max_steps: + value: 26500 +script: |- + WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \ + --train-cluster-path=${data_path}/train_clusters.parquet \ + --train-database-path=${data_path}/train.db \ + --valid-cluster-path=${data_path}/valid_clusters.parquet \ + --valid-database-path=${data_path}/validation.db \ + --micro-batch-size=${batch_size} \ + --num-nodes=${nodes} \ + --num-gpus=${gpus} \ + --val-check-interval=1000 \ + --limit-val-batches=1 \ + --num-steps=${max_steps} \ + --min-seq-length=1024 \ + --max-seq-length=1024 \ + --num-layers=33 \ + --hidden-size=1280 \ + --num-attention-heads=20 \ + --ffn-hidden-size=5120 \ + --create-tensorboard-logger \ + --experiment-name=${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s_${precision}prec \ + --result-dir=${tensorboard_dir} \ + --wandb-project=${wandb_project_name} \ + --wandb-group=${model}_${variant}_${config_name} \ + --wandb-job-type=${pipeline_label} \ + --log-every-n-steps=50 \ + --disable-checkpointing; diff --git a/ci/benchmarks/perf/esm2_pretrain.yaml b/ci/benchmarks/perf/esm2_pretrain.yaml new file mode 100644 index 0000000000..c45ceeb24c --- /dev/null +++ b/ci/benchmarks/perf/esm2_pretrain.yaml @@ -0,0 +1,65 @@ +scope: perf +time_limit: 1800 +script_args: + # All arguments referenced in the script string must be specified here. + # Arguments not referenced in the script string must have the 'arg' field specified. 
+  # See jet/core/configs.py for the specification of the configuration class
+  workspace:
+    value: /workspace/bionemo2
+    key_segment: False
+  data_path:
+    value: /data/20240809_uniref_2024_03/data
+    key_segment: False
+  model: esm2
+  variant: train
+  config_name: 650M
+  precision: bf16-mixed
+  max_steps: 200
+  gpus: 8
+  acc_grad: 1
+  products:
+    - nodes: 1
+      batch_size: 16
+      pp: 1
+      tp: 1
+    - nodes: 2
+      batch_size: 16
+      pp: 2
+      tp: 1
+    - nodes: 2
+      batch_size: 16
+      pp: 1
+      tp: 2
+    - nodes: 2
+      batch_size: 16
+      pp: 1
+      tp: 1
+script: |-
+  WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \
+    --train-cluster-path=${data_path}/train_clusters.parquet \
+    --train-database-path=${data_path}/train.db \
+    --valid-cluster-path=${data_path}/valid_clusters.parquet \
+    --valid-database-path=${data_path}/validation.db \
+    --micro-batch-size=${batch_size} \
+    --num-nodes=${nodes} \
+    --num-gpus=${gpus} \
+    --val-check-interval=50 \
+    --limit-val-batches=1 \
+    --num-steps=${max_steps} \
+    --min-seq-length=1024 \
+    --max-seq-length=1024 \
+    --num-layers=33 \
+    --hidden-size=1280 \
+    --num-attention-heads=20 \
+    --ffn-hidden-size=5120 \
+    --create-tensorboard-logger \
+    --experiment-name=${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s_${precision}prec \
+    --result-dir=${tensorboard_dir} \
+    --wandb-project=${wandb_project_name} \
+    --wandb-group=${model}_${variant}_${config_name} \
+    --wandb-job-type=${pipeline_label} \
+    --log-every-n-steps=10 \
+    --accumulate-grad-batches=${acc_grad} \
+    --pipeline-model-parallel-size=${pp} \
+    --tensor-model-parallel-size=${tp} \
+    --disable-checkpointing;
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py
index 857db8ad48..d67f715bea 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py
@@ -78,6 +78,13 @@ def parse_args():
             default=[0],
             help="Enable nsys profiling for these ranks.",
         )
+        parser.add_argument(
+            "--disable-checkpointing",
+            action="store_false",
+            default=True,
+            dest="create_checkpoint_callback",
+            help="Disable creating a ModelCheckpoint callback.",
+        )
         return parser.parse_args()
 
     def string_to_class(path: str):
@@ -87,7 +94,12 @@ def string_to_class(path: str):
         module = importlib.import_module(module_path)
         return getattr(module, class_name)
 
-    def load_config(config_path: str, model_config_cls: Optional[str], data_config_cls: Optional[str]) -> MainConfig:
+    def load_config(
+        config_path: str,
+        model_config_cls: Optional[str],
+        data_config_cls: Optional[str],
+        create_checkpoint_callback: bool,
+    ) -> MainConfig:
         with open(config_path, "r") as f:
             config_dict = yaml.safe_load(f)
 
@@ -109,10 +121,15 @@ def load_config(config_path: str, model_config_cls: Optional[str], data_config_c
         elif isinstance(data_config_cls, str):
             data_config_cls = string_to_class(data_config_cls)
 
+        # disable checkpointing if called from the command line
+        if not create_checkpoint_callback:
+            config_dict["training_config"]["enable_checkpointing"] = create_checkpoint_callback
+            config_dict["experiment_config"]["create_checkpoint_callback"] = create_checkpoint_callback
+
         return MainConfig[model_config_cls, data_config_cls](**config_dict)
 
     args = parse_args()
-    config = load_config(args.config, args.model_config_cls, args.data_config_cls)
+    config = load_config(args.config, args.model_config_cls, args.data_config_cls, args.create_checkpoint_callback)
 
     if args.nsys_profiling:
         nsys_config = NsysConfig(
diff --git 
a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py index 847da1dd0b..87ba5003cf 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py @@ -71,6 +71,7 @@ def main( wandb_offline: bool = False, wandb_tags: Optional[List[str]] = None, wandb_group: Optional[str] = None, + wandb_job_type: Optional[str] = None, wandb_id: Optional[str] = None, wandb_anonymous: Optional[bool] = False, wandb_log_model: bool = False, @@ -78,6 +79,7 @@ def main( tensor_model_parallel_size: int = 1, create_tensorboard_logger: bool = False, nemo1_init_path: Optional[Path] = None, + create_checkpoint_callback: bool = True, restore_from_checkpoint_path: Optional[str] = None, save_best_checkpoint: bool = True, save_last_checkpoint: bool = True, @@ -129,6 +131,7 @@ def main( wandb_offline (bool): Run offline (data can be streamed later to wandb servers). wandb_tags (Optional[List[str]]): Tags associated with this run wandb_group (Optional[str]): A unique string shared by all runs in a given group + wandb_job_type (Optional[str]): Type of run, which is useful when you're grouping runs together into larger experiments using group. wandb_id (Optional[str]): Sets the version, mainly used to resume a previous run wandb_anonymous (Optional[bool]): Enables or explicitly disables anonymous logging wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers @@ -136,6 +139,7 @@ def main( tensor_model_parallel_size (int): tensor model parallel size create_tensorboard_logger (bool): create the tensorboard logger nemo1_init_path (Optional[Path]): Nemo 1 initialization path + create_checkpoint_callback (bool): create a ModelCheckpoint callback and attach it to the pytorch lightning trainer restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the checkpoint to be created by using the ModelCheckpoint class and always_save_context=True. save_best_checkpoint (bool): whether to save the best checkpoint @@ -199,6 +203,7 @@ def main( entity=wandb_entity, tags=wandb_tags, group=wandb_group, + job_type=wandb_job_type, id=wandb_id, anonymous=wandb_anonymous, log_model=wandb_log_model, @@ -237,6 +242,7 @@ def main( grad_reduce_in_fp32=grad_reduce_in_fp32, autocast_enabled=False, ), + enable_checkpointing=create_checkpoint_callback, ) tokenizer = get_tokenizer() @@ -298,14 +304,17 @@ def main( ) # Configure our custom Checkpointer - checkpoint_callback = nl_callbacks.ModelCheckpoint( - save_last=save_last_checkpoint, - monitor=metric_to_monitor_for_checkpoints, # "val_loss", - save_top_k=save_top_k, - every_n_train_steps=val_check_interval, - always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe - filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. 
- ) + if create_checkpoint_callback: + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_last=save_last_checkpoint, + monitor=metric_to_monitor_for_checkpoints, # "val_loss", + save_top_k=save_top_k, + every_n_train_steps=val_check_interval, + always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe + filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. + ) + else: + checkpoint_callback = None # Setup the logger and train the model nemo_logger = setup_nemo_lightning_logger( @@ -348,6 +357,7 @@ def train_esm2_entrypoint(): wandb_project=args.wandb_project, wandb_tags=args.wandb_tags, wandb_group=args.wandb_group, + wandb_job_type=args.wandb_job_type, wandb_id=args.wandb_id, wandb_anonymous=args.wandb_anonymous, wandb_log_model=args.wandb_log_model, @@ -369,6 +379,7 @@ def train_esm2_entrypoint(): experiment_name=args.experiment_name, resume_if_exists=args.resume_if_exists, nemo1_init_path=args.nemo1_init_path, + create_checkpoint_callback=args.create_checkpoint_callback, restore_from_checkpoint_path=args.restore_from_checkpoint_path, save_best_checkpoint=args.save_best_checkpoint, save_last_checkpoint=args.save_last_checkpoint, @@ -459,6 +470,12 @@ def get_parser(): parser.add_argument( "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group" ) + parser.add_argument( + "--wandb-job-type", + type=str, + default=None, + help="A unique string representing a type of run, which is useful when you're grouping runs together into larger experiments using group.", + ) parser.add_argument( "--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run" ) @@ -580,6 +597,13 @@ def get_parser(): required=False, help="Path to nemo1 file, if desired to load at init time.", ) + parser.add_argument( + "--disable-checkpointing", + action="store_false", + default=True, + dest="create_checkpoint_callback", + help="Disable creating a ModelCheckpoint callback.", + ) parser.add_argument( "--save-best-checkpoint", action="store_true", diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py index ab15ae0b4b..ab9040d395 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py @@ -82,7 +82,10 @@ def dummy_parquet_train_val_inputs(tmp_path): return train_cluster_path, valid_cluster_path -def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs): +@pytest.mark.parametrize("create_checkpoint_callback", [True, False]) +def test_main_runs( + monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs, create_checkpoint_callback +): train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs result_dir = Path(tmpdir.mkdir("results")) @@ -119,6 +122,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra num_attention_heads=2, hidden_size=4, ffn_hidden_size=4 * 4, + create_checkpoint_callback=create_checkpoint_callback, ) assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory." 
@@ -126,12 +130,20 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra children = list((result_dir / "test_experiment").iterdir()) assert len(children) == 1, f"Expected 1 child in test experiment directory, found {children}." uq_rundir = children[0] # it will be some date. - assert ( - result_dir / "test_experiment" / uq_rundir / "checkpoints" - ).exists(), "Could not find test experiment checkpoints directory." - assert ( - result_dir / "test_experiment" / uq_rundir / "checkpoints" - ).is_dir(), "Test experiment checkpoints directory is supposed to be a directory." + + # checking directory with checkpoints + expected_exists = create_checkpoint_callback + actual_exists = (result_dir / "test_experiment" / uq_rundir / "checkpoints").exists() + assert expected_exists == actual_exists, ( + f"Checkpoints directory existence mismatch. " + f"Expected: {'exists' if expected_exists else 'does not exist'}, " + f"Found: {'exists' if actual_exists else 'does not exist'}." + ) + + if create_checkpoint_callback: + assert ( + result_dir / "test_experiment" / uq_rundir / "checkpoints" + ).is_dir(), "Test experiment checkpoints directory is supposed to be a directory." assert ( result_dir / "test_experiment" / uq_rundir / "nemo_log_globalrank-0_localrank-0.txt" ).is_file(), "Could not find experiment log." diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/main.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/main.py index 4b49946cef..377803d95c 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/main.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/main.py @@ -82,6 +82,13 @@ def parse_args(): default=[0], help="Enable nsys profiling for these ranks.", ) + parser.add_argument( + "--disable-checkpointing", + action="store_false", + default=True, + dest="create_checkpoint_callback", + help="Disable creating a ModelCheckpoint callback.", + ) return parser.parse_args() @@ -92,7 +99,12 @@ def string_to_class(path: str): module = importlib.import_module(module_path) return getattr(module, class_name) - def load_config(config_path: str, model_config_cls: Optional[str], data_config_cls: Optional[str]) -> MainConfig: + def load_config( + config_path: str, + model_config_cls: Optional[str], + data_config_cls: Optional[str], + create_checkpoint_callback: bool, + ) -> MainConfig: with open(config_path, "r") as f: config_dict = yaml.safe_load(f) @@ -106,6 +118,11 @@ def load_config(config_path: str, model_config_cls: Optional[str], data_config_c # We assume we get a string to some importable config... e.g. 
in the sub-package jensen, 'bionemo.jensen.configs.MyConfig' model_config_cls = string_to_class(model_config_cls) + # disable checkpointing if called from the command line + if not create_checkpoint_callback: + config_dict["training_config"]["enable_checkpointing"] = create_checkpoint_callback + config_dict["experiment_config"]["create_checkpoint_callback"] = create_checkpoint_callback + if data_config_cls is None: data_config_cls = GeneformerPretrainingDataConfig elif isinstance(data_config_cls, str): @@ -113,7 +130,7 @@ def load_config(config_path: str, model_config_cls: Optional[str], data_config_c return MainConfig[model_config_cls, data_config_cls](**config_dict) args = parse_args() - config = load_config(args.config, args.model_config_cls, args.data_config_cls) + config = load_config(args.config, args.model_config_cls, args.data_config_cls, args.create_checkpoint_callback) if args.nsys_profiling: nsys_config = NsysConfig( diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py index 34b34e59e3..f3e5fa2bd3 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py @@ -75,11 +75,13 @@ def main( wandb_offline: bool = False, wandb_tags: List[str] | None = None, wandb_group: Optional[str] = None, + wandb_job_type: Optional[str] = None, wandb_id: Optional[str] = None, wandb_anonymous: bool = False, wandb_log_model: bool = False, create_tensorboard_logger: bool = False, nemo1_init_path: Path | None = None, + create_checkpoint_callback: bool = True, restore_from_checkpoint_path: Path | None = None, num_layers: int = 6, hidden_size: int = 256, @@ -132,11 +134,13 @@ def main( wandb_project (str): The name of the project to which this run will belong. wandb_tags (List[str]): Tags associated with this run. wandb_group (str): A unique string shared by all runs in a given group + wandb_job_type (Optional[str]): Type of run, which is useful when you're grouping runs together into larger experiments using group. wandb_offline (bool): Run offline (data can be streamed later to wandb servers). wandb_id (str): Sets the version, mainly used to resume a previous run. wandb_anonymous (bool): Enables or explicitly disables anonymous logging. wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers. create_tensorboard_logger (bool): create the tensorboard logger + create_checkpoint_callback (bool): create a ModelCheckpoint callback and attach it to the pytorch lightning trainer restore_from_checkpoint_path (path): If set, restores the model from the directory passed in. Expects the checkpoint to be created by using the ModelCheckpoint class and always_save_context=True. num_layers (int): Number of layers in geneformer. Default to 6. 
@@ -215,6 +219,7 @@ def main( entity=wandb_entity, tags=wandb_tags, group=wandb_group, + job_type=wandb_job_type, id=wandb_id, anonymous=wandb_anonymous, log_model=wandb_log_model, @@ -253,6 +258,7 @@ def main( callbacks=callbacks, use_distributed_sampler=False, plugins=nl.MegatronMixedPrecision(precision=precision), + enable_checkpointing=create_checkpoint_callback, ) preprocessor = GeneformerPreprocess( @@ -328,14 +334,17 @@ def main( ), ) # Configure our custom Checkpointer - checkpoint_callback = nl_callbacks.ModelCheckpoint( - save_last=save_last_checkpoint, - monitor=metric_to_monitor_for_checkpoints, - save_top_k=save_top_k, - every_n_train_steps=val_check_interval, - always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe - filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. - ) + if create_checkpoint_callback: + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_last=save_last_checkpoint, + monitor=metric_to_monitor_for_checkpoints, + save_top_k=save_top_k, + every_n_train_steps=val_check_interval, + always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe + filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. + ) + else: + checkpoint_callback = None # Setup the logger and train the model nemo_logger = setup_nemo_lightning_logger( @@ -409,6 +418,12 @@ def get_parser(): parser.add_argument( "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group" ) + parser.add_argument( + "--wandb-job-type", + type=str, + default=None, + help="A unique string representing a type of run, which is useful when you're grouping runs together into larger experiments using group.", + ) parser.add_argument( "--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run" ) @@ -523,6 +538,13 @@ def get_parser(): required=False, help="Path to nemo1 file, if desired to load at init time.", ) + parser.add_argument( + "--disable-checkpointing", + action="store_false", + default=True, + dest="create_checkpoint_callback", + help="Disable creating a ModelCheckpoint callback.", + ) parser.add_argument( "--save-best-checkpoint", action="store_true", @@ -662,6 +684,7 @@ def entrypoint(): wandb_project=args.wandb_project, wandb_tags=args.wandb_tags, wandb_group=args.wandb_group, + wandb_job_type=args.wandb_job_type, wandb_id=args.wandb_id, wandb_anonymous=args.wandb_anonymous, wandb_log_model=args.wandb_log_model, @@ -684,6 +707,7 @@ def entrypoint(): nsys_start_step=args.nsys_start_step, nsys_end_step=args.nsys_end_step, nsys_ranks=args.nsys_ranks, + create_checkpoint_callback=args.create_checkpoint_callback, restore_from_checkpoint_path=args.restore_from_checkpoint_path, config_class=args.training_model_config_class, save_last_checkpoint=args.save_last_checkpoint, diff --git a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_train_geneformer.py b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_train_geneformer.py index f471b77ea4..6ca0995971 100644 --- a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_train_geneformer.py +++ b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_train_geneformer.py @@ 
-31,7 +31,7 @@ @pytest.fixture def data_path() -> Path: - """Gets the path to the directory with with cellx small dataset in Single Cell Memmap format. + """Gets the path to the directory with cellx small dataset in Single Cell Memmap format. Returns: A Path object that is the directory with the specified test data. """ @@ -43,6 +43,64 @@ def test_bionemo2_rootdir(data_path): assert data_path.is_dir(), "Test data directory is supposed to be a directory." +@pytest.mark.parametrize("create_checkpoint_callback", [True, False]) +def test_main_runs(tmpdir, create_checkpoint_callback: bool, data_path: Path): + result_dir = Path(tmpdir.mkdir("results")) + + with megatron_parallel_state_utils.distributed_model_parallel_state(): + main( + data_dir=data_path, + num_nodes=1, + devices=1, + seq_length=128, + result_dir=result_dir, + wandb_project=None, + wandb_offline=True, + num_steps=5, + limit_val_batches=1, + val_check_interval=2, + num_dataset_workers=0, + biobert_spec_option=BiobertSpecOption.bert_layer_local_spec, + lr=1e-4, + micro_batch_size=2, + accumulate_grad_batches=2, + cosine_rampup_frac=0.01, + cosine_hold_frac=0.01, + precision="bf16-mixed", + experiment_name="test_experiment", + resume_if_exists=False, + create_tensorboard_logger=False, + num_layers=2, + num_attention_heads=2, + hidden_size=4, + ffn_hidden_size=4 * 2, + create_checkpoint_callback=create_checkpoint_callback, + ) + + assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory." + assert (result_dir / "test_experiment").is_dir(), "Test experiment directory is supposed to be a directory." + children = list((result_dir / "test_experiment").iterdir()) + assert len(children) == 1, f"Expected 1 child in test experiment directory, found {children}." + uq_rundir = children[0] # it will be some date. + + expected_exists = create_checkpoint_callback + actual_exists = (result_dir / "test_experiment" / uq_rundir / "checkpoints").exists() + + assert expected_exists == actual_exists, ( + f"Checkpoints directory existence mismatch. " + f"Expected: {'exists' if expected_exists else 'does not exist'}, " + f"Found: {'exists' if actual_exists else 'does not exist'}." + ) + + if create_checkpoint_callback: + assert ( + result_dir / "test_experiment" / uq_rundir / "checkpoints" + ).is_dir(), "Test experiment checkpoints directory is supposed to be a directory." + assert ( + result_dir / "test_experiment" / uq_rundir / "nemo_log_globalrank-0_localrank-0.txt" + ).is_file(), "Could not find experiment log." + + @pytest.mark.parametrize("limit_val_batches", [0.0, 1]) def test_val_dataloader_in_main_runs_with_limit_val_batches(tmpdir, data_path, limit_val_batches: float): result_dir = Path(tmpdir.mkdir("results")) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py b/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py index e6c0f6177d..a2065a2a02 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/run/config_models.py @@ -301,6 +301,7 @@ class TrainingConfig(BaseModel): accelerator (str, optional): The type of accelerator to use for training. Defaults to "gpu". gc_interval (int, optional): The interval of global steps at which to run synchronized garbage collection. Useful for synchronizing garbage collection when performing distributed training. Defaults to 0. include_perplexity (bool, optional): Whether to include perplexity in the validation logs. Defaults to False. 
+ enable_checkpointing (bool, optional): Whether to enable checkpointing and configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint. Corresponds to the same parameter name in pl.Trainer """ max_steps: int @@ -311,6 +312,7 @@ class TrainingConfig(BaseModel): # NOTE: VERY important for distributed training performance. gc_interval: int = 0 include_perplexity: bool = False + enable_checkpointing: bool = True class OptimizerSchedulerConfig(BaseModel): @@ -351,6 +353,7 @@ class ExperimentConfig(BaseModel): metric_to_monitor_for_checkpoints (str): Metric to monitor for saving top-k checkpoints. Default is "reduced_train_loss". save_top_k (int): Number of top checkpoints to save based on the monitored metric. Default is 2. create_tensorboard_logger (bool): Flag to create a TensorBoard logger. Default is False. + create_checkpoint_callback (bool): Flag to create a ModelCheckpoint callback """ save_every_n_steps: int @@ -362,6 +365,7 @@ class ExperimentConfig(BaseModel): metric_to_monitor_for_checkpoints: str = "reduced_train_loss" save_top_k: int = 2 create_tensorboard_logger: bool = False + create_checkpoint_callback: bool = True @field_serializer("result_dir") def serialize_paths(self, value: pathlib.Path) -> str: # noqa: D102 @@ -425,3 +429,9 @@ def run_bionemo_model_config_model_validators(self) -> "MainConfig": def run_data_config_model_validators(self) -> "MainConfig": """Runs the model validators on the data_config.""" return self.data_config.custom_model_validator(self) + + @model_validator(mode="after") + def validate_checkpointing_setting(self) -> "MainConfig": + """Validates the master configuration object.""" + self.training_config.enable_checkpointing = self.experiment_config.create_checkpoint_callback + return self diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/train.py b/sub-packages/bionemo-llm/src/bionemo/llm/train.py index b69e25f8db..094b763af6 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/train.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/train.py @@ -67,14 +67,17 @@ def nemo_logger_factory(experiment_config: ExperimentConfig, wandb_config: Optio Returns: nl.NeMoLogger: An instance of NeMoLogger configured with the specified settings. """ - checkpoint_callback = nl_callbacks.ModelCheckpoint( - save_last=experiment_config.save_last_checkpoint, - monitor=experiment_config.metric_to_monitor_for_checkpoints, - save_top_k=experiment_config.save_top_k, - every_n_train_steps=experiment_config.save_every_n_steps, - always_save_context=True, - filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. - ) + if experiment_config.create_checkpoint_callback: + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_last=experiment_config.save_last_checkpoint, + monitor=experiment_config.metric_to_monitor_for_checkpoints, + save_top_k=experiment_config.save_top_k, + every_n_train_steps=experiment_config.save_every_n_steps, + always_save_context=True, + filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. 
+ ) + else: + checkpoint_callback = None nemo_logger = setup_nemo_lightning_logger( root_dir=experiment_config.result_dir, @@ -167,6 +170,7 @@ def setup_trainer( grad_reduce_in_fp32=False, autocast_enabled=False, ), + enable_checkpointing=training_config.enable_checkpointing, ) return trainer diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py b/sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py index 912d67bf7b..ed59a1a11b 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py @@ -37,6 +37,7 @@ class WandbConfig(BaseModel): project: The name of the project to which this run will belong. tags: Tags associated with this run. group: A unique string shared by all runs in a given group + job_type: Type of run, which is useful when you're grouping runs together into larger experiments. offline: Run offline (data can be streamed later to wandb servers). id: Sets the version, mainly used to resume a previous run. anonymous: Enables or explicitly disables anonymous logging. @@ -47,7 +48,10 @@ class WandbConfig(BaseModel): # name: #Display name for the run. "This is handled by NeMoLogger" # save_dir: #Path where data is saved. "This is handled by NeMoLogger" tags: List[str] | None # Tags associated with this run. - group: str | None # A unique string shared by all runs in a given group + group: str | None # A unique string shared by all runs in a given group. + job_type: str | None = ( + None # Type of run, which is useful when you're grouping runs together into larger experiments. + ) offline: bool # Run offline (data can be streamed later to wandb servers). id: str | None # Sets the version, mainly used to resume a previous run. anonymous: bool # Enables or explicitly disables anonymous logging.
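
The checkpointing toggle introduced by this patch uses the same wiring in every entrypoint: an argparse flag declared with action="store_false" and dest="create_checkpoint_callback" (so the stored value defaults to True and flips to False only when --disable-checkpointing is passed), and that single boolean then gates creation of the ModelCheckpoint callback and is mirrored into the trainer's enable_checkpointing. Below is a minimal, self-contained sketch of that pattern for illustration only; SimpleCheckpointCallback and build_trainer_kwargs are stand-in names, not part of this patch or of the NeMo/PyTorch Lightning APIs.

```python
import argparse
from typing import List, Optional


class SimpleCheckpointCallback:
    """Illustrative stand-in for the real ModelCheckpoint callback."""

    def __init__(self, every_n_train_steps: int):
        self.every_n_train_steps = every_n_train_steps


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Same pattern as the patch: the flag is spelled --disable-checkpointing,
    # but the stored attribute is create_checkpoint_callback (True unless passed).
    parser.add_argument(
        "--disable-checkpointing",
        action="store_false",
        default=True,
        dest="create_checkpoint_callback",
        help="Disable creating a ModelCheckpoint callback.",
    )
    return parser.parse_args(argv)


def build_trainer_kwargs(create_checkpoint_callback: bool, val_check_interval: int = 50) -> dict:
    # Build the callback only when checkpointing is enabled; otherwise keep it None
    # and disable trainer-level checkpointing, mirroring enable_checkpointing above.
    checkpoint_callback = (
        SimpleCheckpointCallback(every_n_train_steps=val_check_interval)
        if create_checkpoint_callback
        else None
    )
    return {
        "callbacks": [checkpoint_callback] if checkpoint_callback else [],
        "enable_checkpointing": create_checkpoint_callback,
    }


if __name__ == "__main__":
    args = parse_args(["--disable-checkpointing"])
    print(build_trainer_kwargs(args.create_checkpoint_callback))
    # -> {'callbacks': [], 'enable_checkpointing': False}
```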