test args
Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
farhadrgh committed Jan 14, 2025
1 parent 8ca52b4 commit c213cdb
Showing 2 changed files with 187 additions and 4 deletions.
@@ -355,7 +355,7 @@ def finetune_esm2_entrypoint():
        dataset_class=args.dataset_class,
        config_class=args.config_class,
        overlap_grad_reduce=not args.no_overlap_grad_reduce,
-        overlap_param_gather=not args.overlap_param_gather,
+        overlap_param_gather=not args.no_overlap_param_gather,
        average_in_collective=not args.no_average_in_collective,
        grad_reduce_in_fp32=args.grad_reduce_in_fp32,
    )
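The one-line change above matters because the entrypoint's parser exposes disable-style flags: argparse populates args.no_overlap_param_gather (asserted in the new test_get_parser below), so the enabled value has to be the negation of that attribute. The old code negated a differently named attribute, which at best inverts the intent and at worst raises AttributeError if the parser only defines the no_* form. A minimal sketch of this flag pattern, using a throwaway parser rather than the project's get_parser:

import argparse

# Minimal sketch (not the project's get_parser): a disable-style flag defaults to
# False, so the feature stays enabled unless the user passes the --no-* option.
parser = argparse.ArgumentParser()
parser.add_argument("--no-overlap-param-gather", action="store_true", default=False)

args = parser.parse_args([])  # flag omitted
overlap_param_gather = not args.no_overlap_param_gather  # True: feature enabled

args = parser.parse_args(["--no-overlap-param-gather"])  # flag supplied
overlap_param_gather = not args.no_overlap_param_gather  # False: feature disabled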
@@ -14,6 +14,9 @@
# limitations under the License.


from pathlib import Path
from unittest.mock import patch

import pandas as pd
import pytest
from nemo.lightning import io
@@ -22,7 +25,7 @@
from bionemo.esm2.model.finetune.dataset import InMemoryPerTokenValueDataset, InMemorySingleValueDataset
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
-from bionemo.esm2.scripts.finetune_esm2 import train_model as finetune
+from bionemo.esm2.scripts.finetune_esm2 import finetune_esm2_entrypoint, get_parser, train_model
from bionemo.testing import megatron_parallel_state_utils
from bionemo.testing.callbacks import MetricTracker

@@ -50,7 +53,7 @@ def test_esm2_finetune_token_classifier(
    seed: int = 42,
):
    with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
-        simple_ft_checkpoint, simple_ft_metrics, trainer = finetune(
+        simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
            train_data_path=data_to_csv(dummy_data_per_token_classification_ft, tmp_path),
            valid_data_path=data_to_csv(dummy_data_per_token_classification_ft, tmp_path),
            experiment_name="finetune_new_head_token_classification",
@@ -95,7 +98,7 @@ def test_esm2_finetune_regressor(
    seed: int = 42,
):
    with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
-        simple_ft_checkpoint, simple_ft_metrics, trainer = finetune(
+        simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
            train_data_path=data_to_csv(dummy_data_single_value_regression_ft, tmp_path),
            valid_data_path=data_to_csv(dummy_data_single_value_regression_ft, tmp_path),
            experiment_name="finetune_new_head_regression",
@@ -130,3 +133,183 @@ def test_esm2_finetune_regressor(
            p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name
        ]
        assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"


@pytest.fixture
def mock_train_model():
    # Patch train_model where the entrypoint resolves it, so no real training is launched.
    with patch("bionemo.esm2.scripts.finetune_esm2.train_model") as mock_train:
        yield mock_train


@pytest.fixture
def mock_parser_args():
    """Fixture to create mock arguments for the parser."""
    return [
        "--train-data-path",
        str(Path("train.csv")),
        "--valid-data-path",
        str(Path("valid.csv")),
        "--num-gpus",
        "1",
        "--num-nodes",
        "1",
        "--min-seq-length",
        "512",
        "--max-seq-length",
        "1024",
        "--result-dir",
        str(Path("./results")),
        "--lr",
        "0.001",
    ]


def test_finetune_esm2_entrypoint(mock_train_model, mock_parser_args):
    """Test the finetune_esm2_entrypoint function with mocked arguments."""
    with patch("sys.argv", ["finetune_esm2_entrypoint.py"] + mock_parser_args):
        finetune_esm2_entrypoint()

        # Check if train_model was called once
        mock_train_model.assert_called_once()

        # Check if the arguments were passed correctly
        called_kwargs = mock_train_model.call_args.kwargs
        assert called_kwargs["train_data_path"] == Path("train.csv")
        assert called_kwargs["valid_data_path"] == Path("valid.csv")
        assert called_kwargs["devices"] == 1
        assert called_kwargs["num_nodes"] == 1
        assert called_kwargs["min_seq_length"] == 512
        assert called_kwargs["max_seq_length"] == 1024
        assert called_kwargs["lr"] == 0.001
        assert called_kwargs["result_dir"] == Path("./results")


def test_get_parser():
    """Test the argument parser with all possible arguments."""
    parser = get_parser()
    args = parser.parse_args(
        [
            "--train-data-path",
            "train.csv",
            "--valid-data-path",
            "valid.csv",
            "--precision",
            "bf16-mixed",
            "--lr",
            "0.001",
            "--create-tensorboard-logger",
            "--resume-if-exists",
            "--result-dir",
            "./results",
            "--experiment-name",
            "esm2_experiment",
            "--wandb-entity",
            "my_team",
            "--wandb-project",
            "geneformer_project",
            "--wandb-tags",
            "tag1",
            "tag2",
            "--wandb-group",
            "group1",
            "--wandb-id",
            "1234",
            "--wandb-anonymous",
            "--wandb-log-model",
            "--wandb-offline",
            "--num-gpus",
            "2",
            "--num-nodes",
            "1",
            "--num-steps",
            "1000",
            "--num-dataset-workers",
            "4",
            "--val-check-interval",
            "500",
            "--log-every-n-steps",
            "100",
            "--min-seq-length",
            "512",
            "--max-seq-length",
            "1024",
            "--limit-val-batches",
            "2",
            "--micro-batch-size",
            "32",
            "--pipeline-model-parallel-size",
            "2",
            "--tensor-model-parallel-size",
            "2",
            "--accumulate-grad-batches",
            "2",
            "--save-last-checkpoint",
            "--metric-to-monitor-for-checkpoints",
            "val_loss",
            "--save-top-k",
            "5",
            "--restore-from-checkpoint-path",
            "./checkpoint",
            "--nsys-profiling",
            "--nsys-start-step",
            "10",
            "--nsys-end-step",
            "50",
            "--nsys-ranks",
            "0",
            "1",
            "--no-overlap-grad-reduce",
            "--no-overlap-param-gather",
            "--no-average-in-collective",
            "--grad-reduce-in-fp32",
            "--dataset-class",
            "InMemorySingleValueDataset",
            "--config-class",
            "ESM2FineTuneSeqConfig",
        ]
    )

    # Assertions for all arguments
    assert args.train_data_path == Path("train.csv")
    assert args.valid_data_path == Path("valid.csv")
    assert args.precision == "bf16-mixed"
    assert args.lr == 0.001
    assert args.create_tensorboard_logger is True
    assert args.resume_if_exists is True
    assert args.result_dir == Path("./results")
    assert args.experiment_name == "esm2_experiment"
    assert args.wandb_entity == "my_team"
    assert args.wandb_project == "geneformer_project"
    assert args.wandb_tags == ["tag1", "tag2"]
    assert args.wandb_group == "group1"
    assert args.wandb_id == "1234"
    assert args.wandb_anonymous is True
    assert args.wandb_log_model is True
    assert args.wandb_offline is True
    assert args.num_gpus == 2
    assert args.num_nodes == 1
    assert args.num_steps == 1000
    assert args.num_dataset_workers == 4
    assert args.val_check_interval == 500
    assert args.log_every_n_steps == 100
    assert args.min_seq_length == 512
    assert args.max_seq_length == 1024
    assert args.limit_val_batches == 2
    assert args.micro_batch_size == 32
    assert args.pipeline_model_parallel_size == 2
    assert args.tensor_model_parallel_size == 2
    assert args.accumulate_grad_batches == 2
    assert args.save_last_checkpoint is True
    assert args.metric_to_monitor_for_checkpoints == "val_loss"
    assert args.save_top_k == 5
    assert args.restore_from_checkpoint_path == Path("./checkpoint")
    assert args.nsys_profiling is True
    assert args.nsys_start_step == 10
    assert args.nsys_end_step == 50
    assert args.nsys_ranks == [0, 1]
    assert args.no_overlap_grad_reduce is True
    assert args.no_overlap_param_gather is True
    assert args.no_average_in_collective is True
    assert args.grad_reduce_in_fp32 is True
    assert args.dataset_class == "InMemorySingleValueDataset"
    assert args.config_class == "ESM2FineTuneSeqConfig"
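As a usage sketch (assuming pytest is installed; the test module's path is not shown in this diff), the tests added here can be selected by name with pytest's -k expression:

import pytest

# Run only the tests added in this commit, selected by name rather than by file path.
pytest.main(["-q", "-k", "test_get_parser or test_finetune_esm2_entrypoint"])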
