diff --git a/docs/docs/user-guide/examples/bionemo-esm2/finetune.md b/docs/docs/user-guide/examples/bionemo-esm2/finetune.md
new file mode 100644
index 0000000000..7968a22397
--- /dev/null
+++ b/docs/docs/user-guide/examples/bionemo-esm2/finetune.md
@@ -0,0 +1,263 @@
+# ESM-2 Fine-Tuning
+
+This tutorial demonstrates how to implement an ESM-2 fine-tuning module, run a regression example, and use the fine-tuned model for inference.
+
+The ESM-2 model is a transformer-based protein language model that has achieved state-of-the-art results in various protein-related tasks. When fine-tuning ESM-2, the task head plays a crucial role. A task head refers to the additional layer or set of layers added on top of a pre-trained model, like the ESM-2 transformer-based protein language model, to adapt it for a specific downstream task. As a part of transfer learning, a pre-trained model is often utilized to learn generic features from a large-scale dataset. However, these features might not be directly applicable to the specific task at hand. By incorporating a task head, which consists of learnable parameters, the model can adapt and specialize to the target task. The task head serves as a flexible and adaptable component that learns task-specific representations by leveraging the pre-trained features as a foundation. Through fine-tuning, the task head enables the model to learn and extract task-specific patterns, improving performance and addressing the nuances of the downstream task. It acts as a critical bridge between the pre-trained model and the specific task, enabling efficient and effective transfer of knowledge.
+
+
+# Setup and Assumptions
+
+In this tutorial, we will demonstrate how to create a fine-tuning module, train a regression task head, and use the fine-tuned model for inference.
+
+All commands should be executed inside the BioNeMo Docker container, which has all ESM-2 dependencies pre-installed. This tutorial assumes that a copy of the BioNeMo framework repo exists on your workstation or server and has been mounted inside the container at `/workspace/bionemo2`. (**Note**: This `WORKDIR` may be `/workspaces/bionemo-framework` if you are using the VSCode Dev Container.) For more information on how to build or pull the BioNeMo2 container, refer to the [Access and Startup](../../getting-started/access-startup.md) guide.
+
+To accomplish this, we need to define some key classes:
+
+1. Loss Reduction Method - Computes the supervised fine-tuning loss.
+2. Fine-Tuned Model Head - Downstream task head model.
+3. Fine-Tuned Model - Model that combines ESM-2 with the task head model.
+4. Fine-Tuning Config - Configures the fine-tuning model and loss to use in the training and inference framework.
+5. Dataset - Training and inference datasets for ESM-2.
+
+## 1 - Loss Reduction Class
+
+A class for calculating the supervised loss of the fine-tuned model from targets. We inherit from the Megatron BERT masked language model loss (`BERTMLMLossWithReduction`) and override the `forward()` method to compute the MSE loss of the regression head within a micro-batch. The `reduce()` method computes the average over the micro-batches and is only used for logging.
+
+```python
+class RegressorLossReduction(BERTMLMLossWithReduction):
+    def forward(
+        self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Union[PerTokenLossDict, SameSizeLossDict]]:
+        targets = batch["labels"]  # [b, 1]
+        regression_output = forward_out
+        loss = torch.nn.functional.mse_loss(regression_output, targets)
+        return loss, {"avg": loss}
+
+    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> torch.Tensor:
+        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+        return losses.mean()
+```
+
+## 2 - Fine-Tuned Model Head
+
+An MLP class for sequence-level regression. This class inherits from `MegatronModule` and uses the fine-tuning config (`TransformerConfig`) to configure the regression head for the fine-tuned ESM-2 model.
+
+```python
+class MegatronMLPHead(MegatronModule):
+    def __init__(self, config: TransformerConfig):
+        super().__init__(config)
+        layer_sizes = [config.hidden_size, 256, 1]
+        self.linear_layers = torch.nn.ModuleList(
+            [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
+        )
+        self.act = torch.nn.ReLU()
+        self.dropout = torch.nn.Dropout(p=config.ft_dropout)
+
+    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
+        ...
+```
+
+## 3 - Fine-Tuned Model
+
+A fine-tuned ESM-2 model class for sequence-level regression. This class inherits from the `ESM2Model` class and adds the custom regression head `MegatronMLPHead` that we created in the previous step. Optionally, one can freeze all or part of the encoder by iterating over the model parameters in the model constructor.
+
+```python
+class ESM2FineTuneSeqModel(ESM2Model):
+    def __init__(self, config, *args, post_process: bool = True, return_embeddings: bool = False, **kwargs):
+        super().__init__(config, *args, post_process=post_process, return_embeddings=True, **kwargs)
+
+        # freeze encoder parameters
+        if config.encoder_frozen:
+            for _, param in self.named_parameters():
+                param.requires_grad = False
+
+        if post_process:
+            self.regression_head = MegatronMLPHead(config)
+
+    def forward(self, *args, **kwargs):
+        output = super().forward(*args, **kwargs)
+        ...
+        regression_output = self.regression_head(embeddings)
+        return regression_output
+```
+
+## 4 - Fine-Tuning Config
+
+A `dataclass` that configures the fine-tuned ESM-2 model. In this example, `ESM2FineTuneSeqConfig` inherits from `ESM2GenericConfig` and adds custom arguments to set up the fine-tuned model. The `configure_model()` method of this `dataclass` is called within the `Lightning` module to invoke the model constructor with the `dataclass` arguments.
+
+The common arguments among different fine-tuning tasks are:
+
+- `model_cls`: The fine-tune model class (`ESM2FineTuneSeqModel`)
+- `initial_ckpt_path`: BioNeMo 2.0 ESM-2 pre-trained checkpoint
+- `initial_ckpt_skip_keys_with_these_prefixes`: Skip keys when loading parameters from a checkpoint. Here we should not look for `regression_head` in the pre-trained checkpoint.
+- `get_loss_reduction_class()`: Implements selection of the appropriate `MegatronLossReduction` class, e.g. `bionemo.esm2.model.finetune.finetune_regressor.RegressorLossReduction`.
+
+```python
+@dataclass
+class ESM2FineTuneSeqConfig(ESM2GenericConfig[ESM2FineTuneSeqModel], iom.IOMixinWithGettersSetters):
+    model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel
+    # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
+    # self.override_parent_fields will be loaded from the checkpoint and override those values here.
+    initial_ckpt_path: str | None = None
+    # The typical case is fine-tuning the base BioBERT model that doesn't have this head. If you are instead loading a
+    # checkpoint that has this new head and want to keep using these weights, please drop this next line or set it to [].
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
+
+    encoder_frozen: bool = True  # freeze encoder parameters
+    ft_dropout: float = 0.25  # MLP layer dropout
+
+    def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:
+        return RegressorLossReduction
+```
+
+## 5 - Dataset
+
+We will use a sample dataset for demonstration purposes. Create a dataset class by extending `torch.utils.data.Dataset`. For the purposes of this demo, we'll assume the dataset consists of a small set of protein sequences with a target value of `len(sequence) / 100.0` as their labels.
+
+```python
+data = [
+    ("MVLSPADKTNVKAAWGKVGAHAGEYGAEALERH", 0.33),
+    ...
+]
+```
+
+The custom BioNeMo dataset class `bionemo.esm2.model.finetune.finetune_regressor.InMemorySingleValueDataset` is appropriate here, as it facilitates predicting a single value per sequence. An excerpt from the class is shown below. This example dataset expects a sequence of tuples that hold `(sequence, target)` values. However, one can extend the `InMemorySingleValueDataset` class in a similar way to customize it for other data.
+
+```python
+class InMemorySingleValueDataset(Dataset):
+    def __init__(
+        self,
+        data: Sequence[Tuple[str, float]],
+        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+        seed: int = np.random.SeedSequence().entropy,
+    ):
+```
+
+For arbitrary data file formats, users can either process the data into a list of `(sequence, label)` tuples and use this dataset class, or override the dataset class to load their custom data files.
+
+To coordinate the creation of training, validation, and testing datasets from your data, we need to use a `datamodule` class. To do this we can directly use or extend the `ESM2FineTuneDataModule` class (located at `bionemo.esm2.model.finetune.datamodule.ESM2FineTuneDataModule`), which defines helpful abstract methods that use your dataset class.
+
+```python
+dataset = InMemorySingleValueDataset(data)
+data_module = ESM2FineTuneDataModule(
+    train_dataset=dataset,
+    valid_dataset=dataset,
+    micro_batch_size=4,  # size of a batch to be processed in a device
+    global_batch_size=8,  # size of batch across all devices. Should be a multiple of micro_batch_size
+)
+```
+
+# Fine-Tuning the Regressor Task Head for ESM2
+
+Now we can put these five requirements together to fine-tune a regressor task head starting from a pre-trained 650M ESM-2 model (`pretrain_ckpt_path`). We can take advantage of a simple training loop in `bionemo.esm2.model.finetune.train` and use the `train_model()` function to start the fine-tuning process as shown below.
+
+```python
+# create a List[Tuple] with (sequence, target) values
+artificial_sequence_data = [
+    "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+    "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+    "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
+    "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+    "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+    "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
+    "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+    "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+    "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
+    "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
+]
+
+data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]
+
+# we are training and validating on the same dataset for simplicity
+dataset = InMemorySingleValueDataset(data)
+data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)
+
+experiment_name = "finetune_regressor"
+n_steps_train = 50
+seed = 42
+
+# To download a 650M pre-trained ESM-2 model
+pretrain_ckpt_path = load("esm2/650m:2.0")
+
+config = ESM2FineTuneSeqConfig(
+    initial_ckpt_path=str(pretrain_ckpt_path)
+)
+
+checkpoint, metrics, trainer = train_model(
+    experiment_name=experiment_name,
+    experiment_dir=Path(experiment_results_dir),  # new checkpoint will land in a subdir of this
+    config=config,  # same config as before since we are just continuing training
+    data_module=data_module,
+    n_steps_train=n_steps_train,
+)
+```
+
+This example is fully implemented in `bionemo.esm2.model.finetune.train` and can be executed with:
+```bash
+python -m bionemo.esm2.model.finetune.train
+```
+
+## Notes
+1. The above example fine-tunes a 650M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load` as shown above.
+    ```bash
+    download_bionemo_data esm2/650m:2.0
+    ```
+    Then pass the output path (e.g. `.../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar`) as the `initial_ckpt_path` argument when setting up the config object:
+    ```python
+    config = ESM2FineTuneSeqConfig(
+        initial_ckpt_path=str(pretrain_ckpt_path)
+    )
+    ```
+2. Due to Megatron limitations, the log produced by the training run iterates on steps/iterations and not epochs. Therefore, the `Training epoch` counter stays at zero while `iteration` and `global_step` increase during the course of training (see the example below).
+    ```bash
+    Training epoch 0, iteration | ... | global_step: | reduced_train_loss: ... | val_loss: ...
+    ```
+    To achieve the same epoch-based effect while training, choose the number of training steps (`n_steps_train`) so that:
+    ```bash
+    n_steps_train * global_batch_size = len(dataset) * desired_num_epochs
+    ```
+    For example, training on a dataset of 100 sequences with a global batch size of 10 for 5 epochs corresponds to `n_steps_train = 50`.
+3. We are using a small dataset of artificial sequences as our fine-tuning data in this example. You may experience over-fitting and observe no change in the validation metrics.
+
+# Fine-Tuned ESM-2 Model Inference
+Now we can use the `infer_esm2` entrypoint (implemented in `bionemo.esm2.scripts.infer_esm2`) to run inference on an example prediction dataset.
+Record the checkpoint path reported at the end of the fine-tuning run, after executing `python -m bionemo.esm2.model.finetune.train` (e.g. `/tmp/tmp1b5wlnba/finetune_regressor/checkpoints/finetune_regressor--reduced_train_loss=0.0016-epoch=0-last`), and use it as the `--checkpoint-path` argument to the inference script.
+
+We download a CSV example dataset of artificial sequences for this inference example.
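+
+If you would rather construct the inference CSV yourself (for example, from the artificial sequences used above), a minimal sketch with `pandas` could look like the following. The output file name is an illustrative choice, and only the `sequences` column is required; an optional `labels` column can be added for fine-tuning use cases.
+
+```python
+import pandas as pd
+
+# Any list of protein sequences works here; we reuse two of the artificial sequences from above.
+sequences = [
+    "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+    "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+]
+
+# Build a dataframe with the required `sequences` column and write it to disk.
+df = pd.DataFrame({"sequences": sequences})
+df.to_csv("sequences.csv", index=False)  # pass this path via --data-path instead of the downloaded sample
+```
+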
Please refer to [ESM-2 Inference](./inference) tutorial for detailed explanation of the arguments and how to create your own CSV file. + +```bash +mkdir -p $WORKDIR/esm2_finetune_tutorial + +# download sample data CSV for inference +DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0) +RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/ + +infer_esm2 --checkpoint-path \ + --data-path $DATA_PATH \ + --results-path $RESULTS_PATH \ + --config-class ESM2FineTuneSeqConfig +``` + +This will create a result `.pt` file under `$WORKDIR/esm2_finetune_tutorial/predictions__rank_0.pt` which can be loaded via PyTorch library in python environment: + +```python +import torch + +# Set the path to results file e.g. /workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt +# results_path = /workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt +results = torch.load(results_path) + +# results is a python dict which includes the following result tensors for this example: +# results['regression_output'] is a tensor with shape: torch.Size([10, 1]) +``` + +## Notes +- ESM2 Inference module takes the `--checkpoint-path` and `--config-class` arguments to create a config object by pointing the path in `initial_ckpt_path`. Since we need to load all the parameters from this checkpoint (and don't skip the head) we reset the `initial_ckpt_skip_keys_with_these_prefixes` in this config. + + ```python + config = ESM2FineTuneSeqConfig( + initial_ckpt_path = , + initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list) + ) + ``` diff --git a/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb index f487e73011..5dfd17964f 100644 --- a/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb +++ b/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb @@ -141,40 +141,11 @@ "execution_count": 4, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Downloading data from 'nvidia/clara/esm2nv650m:2.0' to file '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz'.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{\n", - " \"download_end\": \"2025-01-14 22:01:24\",\n", - " \"download_start\": \"2025-01-14 22:01:05\",\n", - " \"download_time\": \"18s\",\n", - " \"files_downloaded\": 1,\n", - " \"local_path\": \"/home/ubuntu/.cache/bionemo/tmpfj1e52vw/esm2nv650m_v2.0\",\n", - " \"size_downloaded\": \"1.12 GB\",\n", - " \"status\": \"COMPLETED\"\n", - "}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Untarring contents of '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz' to '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar'\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar\n" + "/home/bionemo/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar\n" ] } ], @@ -197,7 +168,7 @@ "metadata": {}, "source": [ "\n", - "We use the `InMemoryProteinDataset` class to load the protein sequence data from a `.csv` file. 
This data file should at least have a `sequences` column and can optionally have a `labels` column used for fine-tuning applications. Here is an example of how to create your own inference input data using a list of sequences in python:" + "We use the `InMemoryCSVDataset` class to load the protein sequence data from a `.csv` file. This data file should at least have a `sequences` column and can optionally have a `labels` column used for fine-tuning applications. Here is an example of how to create your own inference input data using a list of sequences in python:" ] }, { @@ -267,12 +238,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "2025-01-14 22:01:45 - faiss.loader - INFO - Loading faiss with AVX512 support.\n", - "2025-01-14 22:01:45 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n", - "[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n", + "2024-12-16 20:19:23 - faiss.loader - INFO - Loading faiss with AVX512 support.\n", + "2024-12-16 20:19:23 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n", + "[NeMo W 2024-12-16 20:19:24 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n", " warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n", " \n", - "[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n", + "[NeMo W 2024-12-16 20:19:24 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n", " cm = get_cmap(\"Set1\")\n", " \n", "usage: infer_esm2 [-h] --checkpoint-path CHECKPOINT_PATH --data-path DATA_PATH\n", @@ -562,7 +533,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/docs/user-guide/getting-started/development.md b/docs/docs/user-guide/getting-started/development.md index ae8a0997f7..ce97a78cbe 100644 --- a/docs/docs/user-guide/getting-started/development.md +++ b/docs/docs/user-guide/getting-started/development.md @@ -136,7 +136,7 @@ of the model. The fine-tuning steps will be application-specific, but a general 6. **Run inference**: Once the model is fine-tuned, use it to make predictions on new, unseen data. For more information on fine-tuning a model, refer to the [ESM-2 Fine-tuning -Tutorial](../examples/bionemo-esm2/finetune.ipynb). +Tutorial](../examples/bionemo-esm2/finetune.md). 
## Advanced Developer Documentation diff --git a/sub-packages/bionemo-esm2/pyproject.toml b/sub-packages/bionemo-esm2/pyproject.toml index 4acce854fa..aa8f7715ed 100644 --- a/sub-packages/bionemo-esm2/pyproject.toml +++ b/sub-packages/bionemo-esm2/pyproject.toml @@ -22,7 +22,6 @@ bionemo-esm2-train= "bionemo.esm2.run.main:main" bionemo-esm2-recipe= "bionemo.esm2.run.recipes:main" infer_esm2 = "bionemo.esm2.scripts.infer_esm2:infer_esm2_entrypoint" train_esm2 = "bionemo.esm2.scripts.train_esm2:train_esm2_entrypoint" -finetune_esm2 = "bionemo.esm2.scripts.finetune_esm2:finetune_esm2_entrypoint" # Make sure that the tokenizer files are included along with the python files during installation. [tool.setuptools.package-data] diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py index 7104f64373..09526572ef 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py @@ -15,27 +15,119 @@ import functools -from typing import Literal, Union +import os +from typing import Literal, Sequence, Tuple, Union +import numpy as np +import pandas as pd +import torch +import torch.utils.data from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from nemo.lightning.data import WrappedDataLoader from nemo.lightning.pytorch.plugins import MegatronDataSampler from nemo.utils import logging +from torch import Tensor +from torch.utils.data import Dataset from bionemo.core.data.multi_epoch_dataset import IdentityMultiEpochDatasetWrapper, MultiEpochDatasetResampler from bionemo.esm2.data import tokenizer -from bionemo.esm2.model.finetune.dataset import ( - InMemoryPerTokenValueDataset, - InMemoryProteinDataset, - InMemorySingleValueDataset, -) +from bionemo.esm2.model.finetune.finetune_regressor import InMemorySingleValueDataset +from bionemo.esm2.model.finetune.finetune_token_classifier import InMemoryPerTokenValueDataset from bionemo.llm.data import collate from bionemo.llm.data.datamodule import MegatronDataModule +from bionemo.llm.data.types import BertSample from bionemo.llm.utils.datamodule_utils import infer_num_samples Mode = Literal["train", "validation", "test", "predict"] -DATASET_TYPES = Union[InMemoryPerTokenValueDataset, InMemorySingleValueDataset, InMemoryProteinDataset, None] + + +class InMemoryCSVDataset(Dataset): + """An in-memory dataset that tokenize strings into BertSample instances.""" + + def __init__( + self, + data_path: str | os.PathLike, + tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(), + seed: int = np.random.SeedSequence().entropy, # type: ignore + ): + """Initializes a dataset for single-value regression fine-tuning. + + This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the + dataset sequences provided. + + Args: + data_path (str | os.PathLike): A path to the CSV file containing sequences. + labels (Optional[Sequence[float | str]]): An optional sequence of labels with 1:1 mapping to sequences. + tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer(). + seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure + that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is + generated. 
+ """ + self.sequences, self.labels = self.load_data(data_path) + + self.seed = seed + self._len = len(self.sequences) + self.tokenizer = tokenizer + + def __len__(self) -> int: + """The size of the dataset.""" + return self._len + + def __getitem__(self, index: int) -> BertSample: + """Obtains the BertSample at the given index.""" + sequence = self.sequences[index] + tokenized_sequence = self._tokenize(sequence) + + label = tokenized_sequence if len(self.labels) == 0 else torch.Tensor([self.labels[index]]) + # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is + loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids)) + + return { + "text": tokenized_sequence, + "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64), + "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64), + "labels": label, + "loss_mask": loss_mask, + "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64), + } + + def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]: + """Loads data from a CSV file, returning sequences and optionally labels. + + This method should be implemented by subclasses to process labels for their specific dataset. + + Args: + csv_path (str | os.PathLike): The path to the CSV file containing the data. + The file is expected to have at least one column named 'sequence'. A 'label' column is optional. + + Returns: + Tuple[Sequence, Sequence]: A tuple where the first element is a list of sequences and the second element is + a list of labels. If the 'label' column is not present, an empty list is returned for labels. + """ + df = pd.read_csv(csv_path) + sequences = df["sequences"].tolist() + + if "labels" in df.columns: + labels = df["labels"].tolist() + else: + labels = [] + return sequences, labels + + def _tokenize(self, sequence: str) -> Tensor: + """Tokenize a protein sequence. + + Args: + sequence: The protein sequence. + + Returns: + The tokenized sequence. + """ + tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt") + return tensor.flatten() # type: ignore + + +DATASET_TYPES = Union[InMemoryPerTokenValueDataset, InMemorySingleValueDataset, InMemoryCSVDataset, None] class ESM2FineTuneDataModule(MegatronDataModule): diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py deleted file mode 100644 index 542854548d..0000000000 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py +++ /dev/null @@ -1,221 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import os -from typing import Sequence - -import numpy as np -import pandas as pd -import torch -import torch.utils.data -from torch import Tensor -from torch.utils.data import Dataset - -from bionemo.esm2.data import tokenizer -from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX -from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer -from bionemo.llm.data.types import BertSample - - -__all__: Sequence[str] = ( - "InMemoryProteinDataset", - "InMemorySingleValueDataset", - "InMemoryPerTokenValueDataset", -) - - -class InMemoryProteinDataset(Dataset): - """An in-memory dataset that tokenize strings into BertSample instances.""" - - def __init__( - self, - sequences: pd.Series, - labels: pd.Series | None = None, - tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(), - seed: int = np.random.SeedSequence().entropy, # type: ignore - ): - """Initializes a dataset of protein sequences. - - This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the - dataset sequences provided. - - Args: - sequences (pd.Series): A pandas Series containing protein sequences. - labels (pd.Series, optional): A pandas Series containing labels. Defaults to None. - tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer(). - seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure - that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is - generated. - """ - self.sequences = sequences - self.labels = labels - - self.seed = seed - self._len = len(self.sequences) - self.tokenizer = tokenizer - - @classmethod - def from_csv( - cls, csv_path: str | os.PathLike, tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer() - ): - """Class method to create a ProteinDataset instance from a CSV file.""" - df = pd.read_csv(csv_path) - - # Validate presence of required columns - if "sequences" not in df.columns: - raise KeyError("The CSV must contain a 'sequences' column.") - - sequences = df["sequences"] - labels = df["labels"] if "labels" in df.columns else None - return cls(sequences, labels, tokenizer) - - def __len__(self) -> int: - """The size of the dataset.""" - return self._len - - def __getitem__(self, index: int) -> BertSample: - """Obtains the BertSample at the given index.""" - sequence = self.sequences[index] - tokenized_sequence = self._tokenize(sequence) - - label = tokenized_sequence if self.labels is None else self.transform_label(self.labels.iloc[index]) - # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is - loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids)) - - return { - "text": tokenized_sequence, - "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64), - "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64), - "labels": label, - "loss_mask": loss_mask, - "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64), - } - - def _tokenize(self, sequence: str) -> Tensor: - """Tokenize a protein sequence. - - Args: - sequence: The protein sequence. - - Returns: - The tokenized sequence. - """ - tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt") - return tensor.flatten() # type: ignore - - def transform_label(self, label): - """Transform the label. 
- - This method should be implemented by subclass if label needs additional transformation. - - Args: - label: label to be transformed - - Returns: - transformed_label - """ - return label - - -class InMemorySingleValueDataset(InMemoryProteinDataset): - """An in-memory dataset that tokenizes strings into BertSample instances.""" - - def __init__( - self, - sequences: pd.Series, - labels: pd.Series | None = None, - tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(), - seed: int = np.random.SeedSequence().entropy, # type: ignore - ): - """Initializes a dataset for single-value regression fine-tuning. - - This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the - dataset sequences provided. - - Args: - sequences (pd.Series): A pandas Series containing protein sequences. - labels (pd.Series, optional): A pandas Series containing labels. Defaults to None. - tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer(). - seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure - that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is - generated. - """ - super().__init__(sequences, labels, tokenizer, seed) - - def transform_label(self, label: float) -> Tensor: - """Transform the regression label. - - Args: - label: regression value - - Returns: - tokenized label - """ - return torch.tensor([label], dtype=torch.float) - - -class InMemoryPerTokenValueDataset(InMemoryProteinDataset): - """An in-memory dataset of labeled strings, which are tokenized on demand.""" - - def __init__( - self, - sequences: pd.Series, - labels: pd.Series | None = None, - tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(), - seed: int = np.random.SeedSequence().entropy, # type: ignore - ): - """Initializes a dataset for per-token classification fine-tuning. - - This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the - dataset sequences provided. - - Args: - sequences (pd.Series): A pandas Series containing protein sequences. - labels (pd.Series, optional): A pandas Series containing labels. Defaults to None. - tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer(). - seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure - that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is - generated. - """ - super().__init__(sequences, labels, tokenizer, seed) - label_tokenizer = Label2IDTokenizer() - self.label_tokenizer = label_tokenizer.build_vocab("CHE") - self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX - - def transform_label(self, label: str) -> Tensor: - """Transform the sequence label by tokenizing them. - - This method tokenizes the secondary structure token sequences. 
- - Args: - label: secondary structure token sequences to be transformed - - Returns: - tokenized label - """ - label_ids = torch.tensor(self.label_tokenizer.text_to_ids(label)) - - # # for multi-label classification with BCEWithLogitsLoss - # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size) - # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype) - - # for multi-class (mutually exclusive) classification with CrossEntropyLoss - tokenized_labels = label_ids - cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype) - - # add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence - labels = torch.cat((cls_eos, tokenized_labels, cls_eos)) - return labels diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py index 27d3375864..f63a194190 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py @@ -17,13 +17,17 @@ from dataclasses import dataclass, field from typing import Dict, List, Sequence, Tuple, Type +import numpy as np import torch from megatron.core import parallel_state from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig from torch import Tensor +from torch.utils.data import Dataset from bionemo.esm2.api import ESM2GenericConfig, ESM2Model +from bionemo.esm2.data import tokenizer +from bionemo.llm.data.types import BertSample from bionemo.llm.model.biobert.model import BioBertOutput from bionemo.llm.model.loss import BERTMLMLossWithReduction, PerTokenLossDict, SameSizeLossDict from bionemo.llm.utils import iomixin_utils as iom @@ -37,6 +41,7 @@ "MegatronMLPHead", "ESM2FineTuneSeqModel", "ESM2FineTuneSeqConfig", + "InMemorySingleValueDataset", ) @@ -173,3 +178,61 @@ class ESM2FineTuneSeqConfig( def get_loss_reduction_class(self) -> Type[RegressorLossReduction]: """Returns RegressorLossReduction class.""" return RegressorLossReduction + + +class InMemorySingleValueDataset(Dataset): + """An in-memory dataset that tokenizes strings into BertSample instances.""" + + def __init__( + self, + data: Sequence[Tuple[str, float]], + tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(), + seed: int = np.random.SeedSequence().entropy, # type: ignore + ): + """Initializes a dataset for single-value regression fine-tuning. + + This is an in-memory dataset that does not apply masking to the sequence. + + Args: + data (Sequence[Tuple[str, float]]): A sequence of tuples containing the sequence and target data. + tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer(). + seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure + that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is + generated. 
+ """ + self.data = data + self.seed = seed + self._len = len(self.data) + self.tokenizer = tokenizer + + def __len__(self) -> int: + """The size of the dataset.""" + return self._len + + def __getitem__(self, index: int) -> BertSample: + """Obtains the BertSample at the given index.""" + sequence, target = self.data[index] + tokenized_sequence = self._tokenize(sequence) + # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is + loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids)) + + return { + "text": tokenized_sequence, + "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64), + "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64), + "labels": torch.tensor([target], dtype=torch.float), + "loss_mask": loss_mask, + "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64), + } + + def _tokenize(self, sequence: str) -> Tensor: + """Tokenize a protein sequence. + + Args: + sequence: The protein sequence. + + Returns: + The tokenized sequence. + """ + tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt") + return tensor.flatten() # type: ignore diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py index fe67cf2ac8..b0991669d8 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py @@ -15,15 +15,21 @@ from dataclasses import dataclass, field -from typing import Dict, List, Sequence, Tuple, Type +from typing import List, Sequence, Tuple, Type, TypedDict +import numpy as np import torch from megatron.core import parallel_state from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig from torch import Tensor +from torch.utils.data import Dataset from bionemo.esm2.api import ESM2GenericConfig, ESM2Model +from bionemo.esm2.data import tokenizer +from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX +from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer +from bionemo.llm.data.types import BertSample from bionemo.llm.model.biobert.model import BioBertOutput from bionemo.llm.model.loss import BERTMLMLossWithReduction, PerTokenLossDict, SameSizeLossDict from bionemo.llm.utils import iomixin_utils as iom @@ -38,9 +44,25 @@ "MegatronConvNetHead", "ESM2FineTuneTokenModel", "ESM2FineTuneTokenConfig", + "InMemoryPerTokenValueDataset", + "ClassifierInput", + "Esm2FineTuneTokenOutput", ) +class ClassifierInput(TypedDict): + """Used as input in the ClassifierLossReduction's forward method.""" + + labels: Tensor + loss_mask: Tensor + + +class Esm2FineTuneTokenOutput(BioBertOutput): + """Inference output from ESM2FineTuneTokenModel.""" + + classification_output: Tensor + + class ClassifierLossReduction(BERTMLMLossWithReduction): """A class for calculating the cross entropy loss of classification output. @@ -48,7 +70,7 @@ class ClassifierLossReduction(BERTMLMLossWithReduction): """ def forward( - self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor] + self, batch: ClassifierInput, forward_out: Esm2FineTuneTokenOutput ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]: """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU. 
@@ -137,9 +159,9 @@ def __init__(self, config, *args, include_hiddens: bool = False, post_process: b # if we are doing post process (eg pipeline last stage) then we need to add the output layers self.classification_head = MegatronConvNetHead(config) - def forward(self, *args, **kwargs) -> Tensor | BioBertOutput: + def forward(self, *args, **kwargs) -> Tensor | BioBertOutput | Esm2FineTuneTokenOutput: """Inference.""" - output = super().forward(*args, **kwargs) + output: Tensor | BioBertOutput | Esm2FineTuneTokenOutput = super().forward(*args, **kwargs) # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism) if not self.post_process: return output # we are not at the last pipeline stage so just return what the parent has @@ -181,3 +203,80 @@ class ESM2FineTuneTokenConfig( def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]: """The loss function type.""" return ClassifierLossReduction + + +class InMemoryPerTokenValueDataset(Dataset): + """An in-memory dataset of labeled strings, which are tokenized on demand.""" + + def __init__( + self, + data: Sequence[Tuple[str, str]], + tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(), + seed: int = np.random.SeedSequence().entropy, # type: ignore + ): + """Initializes a dataset for per-token classification fine-tuning. + + This is an in-memory dataset that does not apply masking to the sequence. + + Args: + data: A sequence of tuples containing the sequence and target data. + tokenizer: The tokenizer to use. Defaults to tokenizer.get_tokenizer(). + seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to + ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random + seed is generated. 
+ """ + self.data = data + self.seed = seed + self._len = len(self.data) + self.tokenizer = tokenizer + label_tokenizer = Label2IDTokenizer() + self.label_tokenizer = label_tokenizer.build_vocab("CHE") + self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX + + def __len__(self) -> int: + """Length of dataset.""" + return self._len + + def __getitem__(self, index: int) -> BertSample: + """Gets a BertSample associated to the supplied index.""" + sequence, target = self.data[index] + tokenized_sequence = self._tokenize(sequence) + # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is + loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(self.tokenizer.all_special_ids)) + labels = self._tokenize_labels(target) + + return { + "text": tokenized_sequence, + "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64), + "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64), + "labels": labels, + "loss_mask": loss_mask, + "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64), + } + + def _tokenize_labels(self, labels_sequence: str) -> Tensor: + label_ids = torch.tensor(self.label_tokenizer.text_to_ids(labels_sequence)) + + # # for multi-label classification with BCEWithLogitsLoss + # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size) + # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype) + + # for multi-class (mutually exclusive) classification with CrossEntropyLoss + tokenized_labels = label_ids + cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype) + + # add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence + labels = torch.cat((cls_eos, tokenized_labels, cls_eos)) + return labels + + def _tokenize(self, sequence: str) -> Tensor: + """Tokenize a protein sequence. + + Args: + sequence: The protein sequence. + + Returns: + The tokenized sequence. + """ + tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt") + return tensor.flatten() # type: ignore diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py new file mode 100644 index 0000000000..638729e3f4 --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import tempfile +from pathlib import Path +from typing import Sequence, Tuple + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import Callback, RichModelSummary +from lightning.pytorch.loggers import TensorBoardLogger +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from nemo import lightning as nl +from nemo.collections import llm as nllm +from nemo.lightning import resume +from nemo.lightning.nemo_logger import NeMoLogger +from nemo.lightning.pytorch import callbacks as nl_callbacks +from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.lightning.pytorch.callbacks.peft import PEFT +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule + +from bionemo.core.data.load import load +from bionemo.esm2.api import ESM2GenericConfig +from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer +from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule +from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig, InMemorySingleValueDataset +from bionemo.llm.model.biobert.lightning import biobert_lightning_module + + +__all__: Sequence[str] = ("train_model",) + + +def train_model( + experiment_name: str, + experiment_dir: Path, + config: ESM2GenericConfig, + data_module: pl.LightningDataModule, + n_steps_train: int, + metric_tracker: Callback | None = None, + tokenizer: BioNeMoESMTokenizer = get_tokenizer(), + peft: PEFT | None = None, + _use_rich_model_summary: bool = True, +) -> Tuple[Path, Callback | None, nl.Trainer]: + """Trains a BioNeMo ESM2 model using PyTorch Lightning. + + Parameters: + experiment_name: The name of the experiment. + experiment_dir: The directory where the experiment will be saved. + config: The configuration for the ESM2 model. + data_module: The data module for training and validation. + n_steps_train: The number of training steps. + metric_tracker: Optional callback to track metrics + tokenizer: The tokenizer to use. Defaults to `get_tokenizer()`. + peft: The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to None. + _use_rich_model_summary: Whether to use the RichModelSummary callback, omitted in our test suite until + https://nvbugspro.nvidia.com/bug/4959776 is resolved. Defaults to True. + + Returns: + A tuple containing the path to the saved checkpoint, a MetricTracker + object, and the PyTorch Lightning Trainer object. 
+ """ + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_last=True, + save_on_train_epoch_end=True, + monitor="reduced_train_loss", # TODO find out how to get val_loss logged and use "val_loss", + every_n_train_steps=n_steps_train // 2, + always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe + ) + + # Setup the logger and train the model + nemo_logger = NeMoLogger( + log_dir=str(experiment_dir), + name=experiment_name, + tensorboard=TensorBoardLogger(save_dir=experiment_dir, name=experiment_name), + ckpt=checkpoint_callback, + ) + # Needed so that the trainer can find an output directory for the profiler + # ckpt_path needs to be a string for SerDe + optimizer = MegatronOptimizerModule( + config=OptimizerConfig( + lr=5e-4, + optimizer="adam", + use_distributed_optimizer=True, + fp16=config.fp16, + bf16=config.bf16, + ) + ) + module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft) + + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ddp="megatron", + find_unused_parameters=True, + enable_nemo_ckpt_io=True, + ) + + if _use_rich_model_summary: + # RichModelSummary is not used in the test suite until https://nvbugspro.nvidia.com/bug/4959776 is resolved due + # to errors with serialization / deserialization. + callbacks: list[Callback] = [RichModelSummary(max_depth=4)] + else: + callbacks = [] + + if metric_tracker is not None: + callbacks.append(metric_tracker) + if peft is not None: + callbacks.append( + ModelTransform() + ) # Callback needed for PEFT fine-tuning using NeMo2, i.e. biobert_lightning_module(model_transform=peft). + + trainer = nl.Trainer( + accelerator="gpu", + devices=1, + strategy=strategy, + limit_val_batches=2, + val_check_interval=n_steps_train // 2, + max_steps=n_steps_train, + num_nodes=1, + log_every_n_steps=n_steps_train // 2, + callbacks=callbacks, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + nllm.train( + model=module, + data=data_module, + trainer=trainer, + log=nemo_logger, + resume=resume.AutoResume( + resume_if_exists=True, # Looks for the -last checkpoint to continue training. + resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. 
+ ), + ) + ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", "")) + return ckpt_path, metric_tracker, trainer + + +if __name__ == "__main__": + # set the results directory + experiment_results_dir = tempfile.TemporaryDirectory().name + + # create a List[Tuple] with (sequence, target) values + artificial_sequence_data = [ + "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI", + "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF", + "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN", + "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF", + "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI", + "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP", + "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF", + "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF", + "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP", + "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT", + ] + data = [(seq, len(seq) / 100.0) for seq in artificial_sequence_data] + + # we are training and validating on the same dataset for simplicity + dataset = InMemorySingleValueDataset(data) + data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset) + + experiment_name = "finetune_regressor" + n_steps_train = 50 + seed = 42 + + # To download a 650M pre-trained ESM2 model + pretrain_ckpt_path = load("esm2/650m:2.0") + + config = ESM2FineTuneSeqConfig(initial_ckpt_path=str(pretrain_ckpt_path)) + + checkpoint, metrics, trainer = train_model( + experiment_name=experiment_name, + experiment_dir=Path(experiment_results_dir), # new checkpoint will land in a subdir of this + config=config, # same config as before since we are just continuing training + data_module=data_module, + n_steps_train=n_steps_train, + ) + print(f"Experiment completed with checkpoint stored at {checkpoint}") diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py deleted file mode 100644 index 1b35f169ff..0000000000 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py +++ /dev/null @@ -1,635 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -from pathlib import Path -from typing import Dict, List, Optional, Sequence, Tuple, Type, get_args - -from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary -from megatron.core.distributed import DistributedDataParallelConfig -from megatron.core.optimizer import OptimizerConfig -from nemo import lightning as nl -from nemo.collections import llm -from nemo.lightning import resume -from nemo.lightning.pytorch import callbacks as nl_callbacks -from nemo.lightning.pytorch.optim import MegatronOptimizerModule - -from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype -from bionemo.esm2.data.tokenizer import get_tokenizer -from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule -from bionemo.esm2.model.finetune.dataset import ( - InMemoryPerTokenValueDataset, - InMemoryProteinDataset, - InMemorySingleValueDataset, -) -from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig -from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig -from bionemo.llm.model.biobert.lightning import biobert_lightning_module -from bionemo.llm.model.biobert.model import BioBertConfig -from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size -from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger - - -__all__: Sequence[str] = ("train_model", "finetune_esm2_entrypoint", "get_parser") - - -SUPPORTED_CONFIGS = { - "ESM2FineTuneSeqConfig": ESM2FineTuneSeqConfig, - "ESM2FineTuneTokenConfig": ESM2FineTuneTokenConfig, -} - -SUPPORTED_DATASETS = { - "InMemoryProteinDataset": InMemoryProteinDataset, - "InMemorySingleValueDataset": InMemorySingleValueDataset, - "InMemoryPerTokenValueDataset": InMemoryPerTokenValueDataset, -} - - -def train_model( - train_data_path: Path, - valid_data_path: Path, - num_nodes: int, - devices: int, - min_seq_length: Optional[int], - max_seq_length: int, - result_dir: Path, - num_steps: int, - limit_val_batches: int, - val_check_interval: int, - log_every_n_steps: Optional[int], - num_dataset_workers: int, - lr: float, - micro_batch_size: int, - accumulate_grad_batches: int, - experiment_name: str, - resume_if_exists: bool, - precision: PrecisionTypes, - wandb_entity: Optional[str] = None, - wandb_project: Optional[str] = None, - wandb_offline: bool = False, - wandb_tags: Optional[List[str]] = None, - wandb_group: Optional[str] = None, - wandb_id: Optional[str] = None, - wandb_anonymous: Optional[bool] = False, - wandb_log_model: bool = False, - pipeline_model_parallel_size: int = 1, - tensor_model_parallel_size: int = 1, - create_tensorboard_logger: bool = False, - restore_from_checkpoint_path: Optional[str] = None, - save_last_checkpoint: bool = True, - metric_to_monitor_for_checkpoints: str = "val_loss", - save_top_k: int = 2, - nsys_profiling: bool = False, - nsys_start_step: int = 0, - nsys_end_step: Optional[int] = None, - nsys_ranks: List[int] = [0], - dataset_class: Type[InMemoryProteinDataset] = InMemorySingleValueDataset, - config_class: Type[BioBertConfig] = ESM2FineTuneSeqConfig, - metric_tracker: Callback | None = None, - overlap_grad_reduce: bool = True, - overlap_param_gather: bool = True, - average_in_collective: bool = True, - grad_reduce_in_fp32: bool = False, -) -> Tuple[Path, Callback | None, nl.Trainer]: - """Train an ESM2 model on UR data. 
- - Args: - train_data_path (Path): path to train CSV - valid_data_path (Path): path to validation CSV - num_nodes (int): Number of nodes to run on - devices (int): number of devices - min_seq_length (Optional[int]): minimum sequence length - max_seq_length (int): maximum sequence length - result_dir (Path): directory to store results, logs and checkpoints - num_steps (int): number of steps to train the model for - limit_val_batches (int): limit the number of validation global batches to this many - val_check_interval (int): number of steps to periodically check the validation loss - log_every_n_steps (Optional[int]): log every n steps - num_dataset_workers (int): number of dataset workers - lr (float): learning rate - micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size - accumulate_grad_batches (int): number of batches to accumulate gradients for - experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the - result_dir that stores the logs and checkpoints. - resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet] - precision (PrecisionTypes): Precision type for training (e.g., float16, float32) - wandb_entity (Optional[str]): The team posting this run (default: your username or your default team) - wandb_project (Optional[str]): The name of the project to which this run will belong - wandb_offline (bool): Run offline (data can be streamed later to wandb servers). - wandb_tags (Optional[List[str]]): Tags associated with this run - wandb_group (Optional[str]): A unique string shared by all runs in a given group - wandb_id (Optional[str]): Sets the version, mainly used to resume a previous run - wandb_anonymous (Optional[bool]): Enables or explicitly disables anonymous logging - wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers - pipeline_model_parallel_size (int): pipeline model parallel size - tensor_model_parallel_size (int): tensor model parallel size - create_tensorboard_logger (bool): create the tensorboard logger - restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the - checkpoint to be created by using the ModelCheckpoint class and always_save_context=True. - save_last_checkpoint (bool): whether to save the last checkpoint - metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints - save_top_k (int): number of top checkpoints to save - nsys_profiling (bool): whether to enable nsys profiling - nsys_start_step (int): start step for nsys profiling - nsys_end_step (Optional[int]): end step for nsys profiling - nsys_ranks (List[int]): ranks for nsys profiling - dataset_class (Type[InMemoryProteinDataset]): The dataset class for loading the data from a CSV file - config_class (Type[BioBertConfig]): The config class for configuring the model using checkpoint provided - metric_tracker: Optional callback to track metrics (used for testing) - overlap_grad_reduce (bool): overlap gradient reduction - overlap_param_gather (bool): overlap parameter gather - average_in_collective (bool): average in collective - grad_reduce_in_fp32 (bool): gradient reduction in fp32 - """ - # Create the result directory if it does not exist. 
- result_dir.mkdir(parents=True, exist_ok=True) - - # Setup the strategy and trainer - global_batch_size = infer_global_batch_size( - micro_batch_size=micro_batch_size, - num_nodes=num_nodes, - devices=devices, - accumulate_grad_batches=accumulate_grad_batches, - tensor_model_parallel_size=tensor_model_parallel_size, - pipeline_model_parallel_size=pipeline_model_parallel_size, - ) - - strategy = nl.MegatronStrategy( - tensor_model_parallel_size=tensor_model_parallel_size, - pipeline_model_parallel_size=pipeline_model_parallel_size, - ddp=DistributedDataParallelConfig( - check_for_nan_in_grad=True, - overlap_grad_reduce=overlap_grad_reduce, - overlap_param_gather=overlap_param_gather, - average_in_collective=average_in_collective, - grad_reduce_in_fp32=grad_reduce_in_fp32, - use_distributed_optimizer=True, - ), - find_unused_parameters=True, - gradient_as_bucket_view=True, - ckpt_include_optimizer=True, - ckpt_async_save=True, - ckpt_parallel_load=True, - ) - - # for wandb integration - # Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/lightning.pytorch.loggers.html" - wandb_config: Optional[WandbConfig] = ( - None - if wandb_project is None - else WandbConfig( - offline=wandb_offline, - project=wandb_project, - entity=wandb_entity, - tags=wandb_tags, - group=wandb_group, - id=wandb_id, - anonymous=wandb_anonymous, - log_model=wandb_log_model, - ) - ) - - callbacks = [ - RichModelSummary(max_depth=4), - LearningRateMonitor(), - nl_callbacks.PreemptionCallback(), - ] - if metric_tracker is not None: - callbacks.append(metric_tracker) - if nsys_profiling: - if nsys_end_step is None: - nsys_end_step = num_steps - callbacks.append( - nl_callbacks.NsysCallback( - start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True - ) - ) - - trainer = nl.Trainer( - devices=devices, - max_steps=num_steps, - accelerator="gpu", - strategy=strategy, - limit_val_batches=limit_val_batches, # This controls upsampling and downsampling - val_check_interval=val_check_interval, - log_every_n_steps=log_every_n_steps, - num_nodes=num_nodes, - callbacks=callbacks, - plugins=nl.MegatronMixedPrecision( - precision=precision, - params_dtype=get_autocast_dtype(precision), - pipeline_dtype=get_autocast_dtype(precision), - grad_reduce_in_fp32=grad_reduce_in_fp32, - autocast_enabled=False, - ), - ) - - tokenizer = get_tokenizer() - - # Initialize the data module. - train_dataset = dataset_class.from_csv(train_data_path) - valid_dataset = dataset_class.from_csv(valid_data_path) - - data_module = ESM2FineTuneDataModule( - train_dataset=train_dataset, - valid_dataset=valid_dataset, - global_batch_size=global_batch_size, - micro_batch_size=micro_batch_size, - min_seq_length=min_seq_length, - max_seq_length=max_seq_length, - num_workers=num_dataset_workers, - tokenizer=tokenizer, - ) - # Configure the model - config = config_class( - params_dtype=get_autocast_dtype(precision), - pipeline_dtype=get_autocast_dtype(precision), - autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot - tensor_model_parallel_size=tensor_model_parallel_size, - pipeline_model_parallel_size=pipeline_model_parallel_size, - initial_ckpt_path=str(restore_from_checkpoint_path), - # initial_ckpt_skip_keys_with_these_prefixes=[], # load everything from the checkpoint. 
- ) - - optimizer = MegatronOptimizerModule( - config=OptimizerConfig( - lr=lr, - optimizer="adam", # fused_adam not supported - use_distributed_optimizer=True, - weight_decay=0.01, - adam_beta1=0.9, - adam_beta2=0.98, - ) - ) - - module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer) - - # Configure our custom Checkpointer - checkpoint_callback = nl_callbacks.ModelCheckpoint( - save_last=save_last_checkpoint, - monitor=metric_to_monitor_for_checkpoints, # "val_loss", - save_top_k=save_top_k, - every_n_train_steps=val_check_interval, - always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe - filename="checkpoint-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. - ) - - # Setup the logger and train the model - nemo_logger = setup_nemo_lightning_logger( - root_dir=result_dir, - name=experiment_name, - initialize_tensorboard_logger=create_tensorboard_logger, - wandb_config=wandb_config, - ckpt_callback=checkpoint_callback, - ) - - llm.train( - model=module, - data=data_module, - trainer=trainer, - log=nemo_logger, - resume=resume.AutoResume( - resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training. - resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. - ), - ) - ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", "")) - return ckpt_path, metric_tracker, trainer - - -def finetune_esm2_entrypoint(): - """Entrypoint for running ESM2 finetuning.""" - # 1. get arguments - parser = get_parser() - args = parser.parse_args() - # 2. Call pretrain with args - train_model( - train_data_path=args.train_data_path, - valid_data_path=args.valid_data_path, - num_nodes=args.num_nodes, - devices=args.num_gpus, - min_seq_length=args.min_seq_length, - max_seq_length=args.max_seq_length, - result_dir=args.result_dir, - wandb_entity=args.wandb_entity, - wandb_project=args.wandb_project, - wandb_tags=args.wandb_tags, - wandb_group=args.wandb_group, - wandb_id=args.wandb_id, - wandb_anonymous=args.wandb_anonymous, - wandb_log_model=args.wandb_log_model, - wandb_offline=args.wandb_offline, - num_steps=args.num_steps, - limit_val_batches=args.limit_val_batches, - val_check_interval=args.val_check_interval, - log_every_n_steps=args.log_every_n_steps, - num_dataset_workers=args.num_dataset_workers, - lr=args.lr, - micro_batch_size=args.micro_batch_size, - pipeline_model_parallel_size=args.pipeline_model_parallel_size, - tensor_model_parallel_size=args.tensor_model_parallel_size, - accumulate_grad_batches=args.accumulate_grad_batches, - precision=args.precision, - experiment_name=args.experiment_name, - resume_if_exists=args.resume_if_exists, - restore_from_checkpoint_path=args.restore_from_checkpoint_path, - save_last_checkpoint=args.save_last_checkpoint, - metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints, - save_top_k=args.save_top_k, - nsys_profiling=args.nsys_profiling, - nsys_start_step=args.nsys_start_step, - nsys_end_step=args.nsys_end_step, - nsys_ranks=args.nsys_ranks, - dataset_class=args.dataset_class, - config_class=args.config_class, - overlap_grad_reduce=not args.no_overlap_grad_reduce, - overlap_param_gather=not args.no_overlap_param_gather, - average_in_collective=not args.no_average_in_collective, - grad_reduce_in_fp32=args.grad_reduce_in_fp32, - ) - - -def get_parser(): - """Return the cli parser for 
this tool.""" - # TODO migrate to hydra config - # Parse the arguments and pull them out into local variables for ease of future refactor to a - # config management system. - parser = argparse.ArgumentParser(description="Pretrain ESM2 with UR data.") - parser.add_argument( - "--train-data-path", - type=Path, - required=True, - help="Path to the train data CSV file", - ) - parser.add_argument( - "--valid-data-path", - type=Path, - required=True, - help="Path to the valid data CSV file", - ) - parser.add_argument( - "--precision", - type=str, - choices=get_args(PrecisionTypes), - required=False, - default="bf16-mixed", - help="Precision type to use for training.", - ) - parser.add_argument( - "--lr", - type=float, - required=False, - default=4e-4, - help="Learning rate for training. Default is 4e-4", - ) - parser.add_argument( - "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger." - ) - # FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer - parser.add_argument( - "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists." - ) - parser.add_argument( - "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory." - ) - parser.add_argument("--experiment-name", type=str, required=False, default="esm2", help="Name of the experiment.") - - parser.add_argument("--wandb-entity", type=str, default=None, help="The team posting this run") - parser.add_argument("--wandb-project", type=str, default=None, help="Wandb project name ") - parser.add_argument("--wandb-tags", nargs="+", type=str, default=None, help="Tags associated with this run") - parser.add_argument( - "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group" - ) - parser.add_argument( - "--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run" - ) - parser.add_argument( - "--wandb-anonymous", action="store_true", help="Enable or explicitly disable anonymous logging" - ) - parser.add_argument( - "--wandb-log-model", action="store_true", help="Save checkpoints in wandb dir to upload on W&B servers" - ) - parser.add_argument("--wandb-offline", action="store_true", help="Use wandb in offline mode") - parser.add_argument( - "--num-gpus", - type=int, - required=False, - default=1, - help="Number of GPUs to use for training. Default is 1.", - ) - parser.add_argument( - "--num-nodes", - type=int, - required=False, - default=1, - help="Number of nodes to use for training. Default is 1.", - ) - parser.add_argument( - "--num-steps", - type=int, - required=False, - default=500000, - help="Number of steps to use for training. Default is 500000.", - ) - parser.add_argument( - "--num-dataset-workers", - type=int, - required=False, - default=1, - help="Number of workers to use for training. Default is 1.", - ) - parser.add_argument( - "--val-check-interval", - type=int, - required=False, - default=10000, - help="Number of steps between validation. Default is 10000.", - ) - parser.add_argument( - "--log-every-n-steps", - type=int, - required=False, - help="Number of steps between logging. Default is 50.", - ) - parser.add_argument( - "--min-seq-length", - type=float_or_int_or_none, - required=False, - default=1024, - help="Minimum sequence length. Sampled will be padded if less than this value. 
Set 'None' to unset minimum.", - ) - parser.add_argument( - "--max-seq-length", - type=int, - required=False, - default=1024, - help="Maximum sequence length. Samples will be truncated if exceeds this value.", - ) - parser.add_argument( - "--limit-val-batches", - type=float_or_int_or_none, - required=False, - default=2, - help="Number of global batches used for validation if int. Fraction of validation dataset if float. Default is 2.", - ) - parser.add_argument( - "--micro-batch-size", - type=int, - required=False, - default=64, - help="Micro-batch size. Global batch size is inferred from this.", - ) - parser.add_argument( - "--pipeline-model-parallel-size", - type=int, - required=False, - default=1, - help="Pipeline model parallel size. Default is 1.", - ) - parser.add_argument( - "--tensor-model-parallel-size", - type=int, - required=False, - default=1, - help="Tensor model parallel size. Default is 1.", - ) - parser.add_argument( - "--accumulate-grad-batches", - type=int, - required=False, - default=1, - help="Gradient accumulation steps. Global batch size is inferred from this.", - ) - parser.add_argument( - "--save-last-checkpoint", - action="store_true", - default=True, - help="Save the last checkpoint.", - ) - parser.add_argument( - "--metric-to-monitor-for-checkpoints", - type=str, - required=False, - default="val_loss", - help="The metric to monitor for checkpointing.", - ) - parser.add_argument( - "--save-top-k", - type=int, - required=False, - default=2, - help="Save the top k checkpoints.", - ) - parser.add_argument( - "--restore-from-checkpoint-path", - type=Path, - required=False, - default=None, - help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.", - ) - parser.add_argument( - "--nsys-profiling", - action="store_true", - default=False, - help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling output you must run the whole program with `nsys`. 
For example: " - " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop [regular python command here]`", - ) - # start, end, rank - parser.add_argument( - "--nsys-start-step", - type=int, - required=False, - default=0, - help="Start nsys profiling after this step.", - ) - parser.add_argument( - "--nsys-end-step", - type=int, - required=False, - help="End nsys profiling after this step.", - ) - # rank as list of integers - parser.add_argument( - "--nsys-ranks", - type=int, - nargs="+", - required=False, - default=[0], - help="Enable nsys profiling for these ranks.", - ) - # DDP config - parser.add_argument( - "--no-overlap-grad-reduce", - action="store_true", - default=False, - ) - parser.add_argument( - "--no-overlap-param-gather", - action="store_true", - default=False, - ) - parser.add_argument( - "--no-average-in-collective", - action="store_true", - default=False, - ) - parser.add_argument( - "--grad-reduce-in-fp32", - action="store_true", - default=False, - ) - - config_class_options: Dict[str, Type[BioBertConfig]] = SUPPORTED_CONFIGS - - def config_class_type(desc: str) -> Type[BioBertConfig]: - try: - return config_class_options[desc] - except KeyError: - raise argparse.ArgumentTypeError( - f"Do not recognize key {desc}, valid options are: {config_class_options.keys()}" - ) - - parser.add_argument( - "--config-class", - type=config_class_type, - default=ESM2FineTuneSeqConfig, - help="Model configs link model classes with losses, and handle model initialization (including from a prior " - "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model " - "class and loss, then implement and provide an alternative config class that points to a variant of that model " - "and alternative loss. In the future this script should also provide similar support for picking different data " - f"modules for fine-tuning with different data types. Choices: {config_class_options.keys()}", - ) - - dataset_class_options: Dict[str, Type[InMemoryProteinDataset]] = SUPPORTED_DATASETS - - def dataset_class_type(desc: str) -> Type[InMemoryProteinDataset]: - try: - return dataset_class_options[desc] - except KeyError: - raise argparse.ArgumentTypeError( - f"Do not recognize key {desc}, valid options are: {dataset_class_options.keys()}" - ) - - parser.add_argument( - "--dataset-class", - type=dataset_class_type, - default=InMemorySingleValueDataset, - help=f"Dataset class name for finetuning. 
Choices: {config_class_options.keys()}", - ) - return parser - - -if __name__ == "__main__": - finetune_esm2_entrypoint() diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py index 9531165c13..bdfaa4fe6b 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py @@ -23,8 +23,7 @@ from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype from bionemo.esm2.api import ESM2Config from bionemo.esm2.data.tokenizer import get_tokenizer -from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule -from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset +from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule, InMemoryCSVDataset from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig from bionemo.llm.model.biobert.lightning import biobert_lightning_module @@ -111,7 +110,7 @@ def infer_model( plugins=nl.MegatronMixedPrecision(precision=precision), ) - dataset = InMemoryProteinDataset.from_csv(data_path) + dataset = InMemoryCSVDataset(data_path=data_path) datamodule = ESM2FineTuneDataModule( predict_dataset=dataset, micro_batch_size=micro_batch_size, diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py index fc896b54a2..2e0c7a0f7f 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py @@ -87,14 +87,3 @@ def dummy_data_single_value_regression_ft(dummy_data_per_token_classification_ft """ data = [(seq, len(seq) / 100.0) for seq, _ in dummy_data_per_token_classification_ft] return data - - -@pytest.fixture -def dummy_protein_sequences(dummy_data_per_token_classification_ft): - """Fixture providing dummy data for per-token classification fine-tuning. - - Returns: - list: A list of dummy data for per-token classification fine-tuning. - """ - data = [seq for seq, _ in dummy_data_per_token_classification_ft] - return data diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py deleted file mode 100644 index c1ccfb5284..0000000000 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
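
For reference, the `infer_esm2.py` hunk above wires a CSV of protein sequences into prediction through `InMemoryCSVDataset` and `ESM2FineTuneDataModule`. The sketch below is illustrative only and assumes the BioNeMo container environment; the file name `proteins.csv` and the two sequences are placeholders, while the `sequences` column name, the constructor arguments, and the `"text"` batch key follow the fixtures and tests elsewhere in this patch.

```python
# Minimal sketch of the inference data path (placeholder file name and sequences).
import pandas as pd

from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule, InMemoryCSVDataset

# A prediction CSV only needs a "sequences" column, mirroring the dummy_protein_csv fixture.
df = pd.DataFrame(
    ["TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI", "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF"],
    columns=["sequences"],
)
df.to_csv("proteins.csv", index=False)

dataset = InMemoryCSVDataset(data_path="proteins.csv")
data_module = ESM2FineTuneDataModule(predict_dataset=dataset, micro_batch_size=2)

# Prediction batches are dictionaries of tensors keyed by "text" (the tokenized sequences).
batch = next(iter(data_module.predict_dataloader()))
assert "text" in batch
```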
- - -import pandas as pd -import pytest -from torch.utils.data import DataLoader - -from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule -from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset - - -@pytest.fixture -def dummy_protein_csv(tmp_path, dummy_protein_sequences): - """Create a mock protein dataset.""" - csv_file = tmp_path / "protein_dataset.csv" - # Create a DataFrame - df = pd.DataFrame(dummy_protein_sequences, columns=["sequences"]) - - # Save the DataFrame to a CSV file - df.to_csv(csv_file, index=False) - return csv_file - - -@pytest.fixture -def dataset(dummy_protein_csv): - return InMemoryProteinDataset.from_csv(dummy_protein_csv) - - -@pytest.fixture -def data_module(dataset): - return ESM2FineTuneDataModule(predict_dataset=dataset) - - -def test_in_memory_csv_dataset(dataset): - assert len(dataset) > 0 - sample = dataset[0] - assert isinstance(sample, dict) - assert "text" in sample - assert "labels" in sample - - -def test_esm2_fine_tune_data_module_init(data_module): - assert data_module.train_dataset is None - assert data_module.valid_dataset is None - assert data_module.predict_dataset is not None - - -def test_esm2_fine_tune_data_module_predict_dataloader(data_module): - predict_dataloader = data_module.predict_dataloader() - assert isinstance(predict_dataloader, DataLoader) - batch = next(iter(predict_dataloader)) - assert isinstance(batch, dict) - assert "text" in batch - - -def test_esm2_fine_tune_data_module_setup(data_module): - with pytest.raises(RuntimeError): - data_module.setup("fit") - - -def test_esm2_fine_tune_data_module_train_dataloader(data_module): - with pytest.raises(AttributeError): - data_module.train_dataloader() - - -def test_esm2_fine_tune_data_module_val_dataloader(data_module): - with pytest.raises(AttributeError): - data_module.val_dataloader() diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py deleted file mode 100644 index afcd53feab..0000000000 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import pandas as pd -import pytest -import torch -from torch import Tensor - -from bionemo.esm2.model.finetune.dataset import ( - InMemoryPerTokenValueDataset, - InMemoryProteinDataset, - InMemorySingleValueDataset, -) -from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX -from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer - - -def data_to_csv(data, tmp_path, with_label=True): - """Create a mock protein dataset.""" - csv_file = tmp_path / "protein_dataset.csv" - # Create a DataFrame - df = pd.DataFrame(data, columns=["sequences", "labels"] if with_label else ["sequences"]) - - # Save the DataFrame to a CSV file - df.to_csv(csv_file, index=False) - return csv_file - - -@pytest.fixture -def dataset_no_labels(dummy_protein_sequences, tmp_path): - csv_path = data_to_csv(dummy_protein_sequences, tmp_path, with_label=False) - return InMemoryProteinDataset.from_csv(csv_path) - - -@pytest.fixture -def dataset_regression_labels(dummy_data_single_value_regression_ft, tmp_path): - csv_path = data_to_csv(dummy_data_single_value_regression_ft, tmp_path, with_label=True) - return InMemorySingleValueDataset.from_csv(csv_path) - - -@pytest.fixture -def dataset_per_token_classification_labels(dummy_data_per_token_classification_ft, tmp_path): - csv_path = data_to_csv(dummy_data_per_token_classification_ft, tmp_path, with_label=True) - return InMemoryPerTokenValueDataset.from_csv(csv_path) - - -def test_in_memory_protein_dataset_length_no_labels(dataset_no_labels, dummy_protein_sequences): - assert len(dataset_no_labels) == len(dummy_protein_sequences) - - -def test_in_memory_protein_dataset_length_with_regression_labels( - dataset_regression_labels, dummy_data_single_value_regression_ft -): - assert len(dataset_regression_labels) == len(dummy_data_single_value_regression_ft) - - -def test_in_memory_protein_dataset_length_with_class_labels( - dataset_per_token_classification_labels, dummy_data_per_token_classification_ft -): - assert len(dataset_per_token_classification_labels) == len(dummy_data_per_token_classification_ft) - - -def test_in_memory_protein_dataset_getitem_no_labels(dataset_no_labels): - sample = dataset_no_labels[0] - assert isinstance(sample, dict) - assert "text" in sample - assert "labels" in sample - assert isinstance(sample["text"], Tensor) - assert isinstance(sample["labels"], Tensor) - - -def test_in_memory_protein_dataset_getitem_with_regression_labels(dataset_regression_labels): - assert isinstance(dataset_regression_labels, InMemoryProteinDataset) - sample = dataset_regression_labels[0] - assert isinstance(sample, dict) - assert "text" in sample - assert "labels" in sample - assert isinstance(sample["text"], Tensor) - assert isinstance(sample["labels"], Tensor) - assert sample["labels"].dtype == torch.float - - -def test_in_memory_protein_dataset_getitem_with_class_labels(dataset_per_token_classification_labels): - assert isinstance(dataset_per_token_classification_labels, InMemoryProteinDataset) - assert isinstance(dataset_per_token_classification_labels.label_tokenizer, Label2IDTokenizer) - assert dataset_per_token_classification_labels.label_cls_eos_id == MLM_LOSS_IGNORE_INDEX - - sample = dataset_per_token_classification_labels[0] - assert isinstance(sample, dict) - assert "text" in sample - assert "labels" in sample - assert isinstance(sample["text"], Tensor) - assert isinstance(sample["labels"], Tensor) - assert sample["labels"].dtype == torch.int64 - - -def test_in_memory_protein_dataset_tokenization(dataset_no_labels): - sample = 
dataset_no_labels[0] - tokenized_sequence = sample["text"] - assert isinstance(tokenized_sequence, Tensor) - assert tokenized_sequence.ndim == 1 # Ensure it's flattened. - - -def test_transofrm_classification_label( - dataset_per_token_classification_labels, dummy_data_per_token_classification_ft -): - pre_transfrom = dummy_data_per_token_classification_ft[0][1] - label_ids = torch.tensor(dataset_per_token_classification_labels.label_tokenizer.text_to_ids(pre_transfrom)) - cls_eos = torch.tensor([dataset_per_token_classification_labels.label_cls_eos_id]) - post_transform = torch.cat((cls_eos, label_ids, cls_eos)) - - assert torch.equal(dataset_per_token_classification_labels.transform_label(pre_transfrom), post_transform) - - -def test_transofrm_regression_label(dataset_regression_labels): - """Ensure labels are transformed correctly.""" - transformed_label = dataset_regression_labels.transform_label(1.0) - assert isinstance(transformed_label, Tensor) - assert transformed_label.dtype == torch.float - - -def test_in_memory_protein_dataset_no_labels_fallback(dataset_no_labels): - """Ensure the dataset works even when labels are missing.""" - sample = dataset_no_labels[0] - assert isinstance(sample, dict) - assert "labels" in sample - assert isinstance(sample["labels"], Tensor) - - -def test_in_memory_protein_dataset_invalid_index(dataset_no_labels): - """Test if out-of-range index raises an error.""" - with pytest.raises(KeyError): - _ = dataset_no_labels[100] - - -def test_in_memory_protein_dataset_missing_sequences_column(tmp_path): - """Test behavior when the CSV file is empty.""" - csv_file = tmp_path / "invalid.csv" - pd.DataFrame({"wrong_column": ["MKTFFS"]}).to_csv(csv_file, index=False) - with pytest.raises(KeyError): - _ = InMemoryProteinDataset.from_csv(csv_file) - - -def test_in_memory_protein_dataset_special_tokens_masking(dataset_no_labels): - """Ensure loss mask correctly handles special tokens.""" - sample = dataset_no_labels[0] - assert "loss_mask" in sample - assert isinstance(sample["loss_mask"], Tensor) - assert sample["loss_mask"].dtype == torch.bool - - -def test_in_memory_protein_dataset_non_existent_file(): - """Ensure proper error handling for missing files.""" - with pytest.raises(FileNotFoundError): - InMemoryProteinDataset.from_csv("non_existent_file.csv") diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py new file mode 100644 index 0000000000..04dfd6a443 --- /dev/null +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
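
The new `test_finetune.py` below exercises fine-tuning through the in-package `train_model` helper and builds its datasets directly from Python lists rather than CSV files. As a quick orientation, here is a hedged sketch of that input shape for the regression case; the two example sequences and their `len(seq) / 100.0` labels mirror the `dummy_data_single_value_regression_ft` fixture in `conftest.py`, and the constructor calls follow the test code that comes after.

```python
# Sketch of the (sequence, value) pairs consumed by the regression fine-tuning test.
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
from bionemo.esm2.model.finetune.finetune_regressor import InMemorySingleValueDataset

# Example regression data: one float label per sequence (here, len(seq) / 100.0).
data = [
    ("TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI", 0.57),
    ("LYSGDHSTQGARFLRDLAENTGRAEYELLSLF", 0.32),
]

dataset = InMemorySingleValueDataset(data, seed=42)
# The same dataset instance serves as both the train and validation split, as in the test.
data_module = ESM2FineTuneDataModule(dataset, dataset)
```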
+ + +import pytest +from nemo.lightning import io + +from bionemo.core.data.load import load +from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule +from bionemo.esm2.model.finetune.finetune_regressor import ( + ESM2FineTuneSeqConfig, + InMemorySingleValueDataset, +) +from bionemo.esm2.model.finetune.finetune_token_classifier import ( + ESM2FineTuneTokenConfig, + InMemoryPerTokenValueDataset, +) +from bionemo.esm2.model.finetune.peft import ESM2LoRA +from bionemo.esm2.model.finetune.train import train_model +from bionemo.testing import megatron_parallel_state_utils +from bionemo.testing.callbacks import MetricTracker + + +# To download a 8M internally pre-trained ESM2 model +pretrain_ckpt_path = load("esm2/8m:2.0") + + +@pytest.mark.needs_gpu +@pytest.mark.parametrize("with_peft", [True, False]) +def test_esm2_finetune_token_classifier( + tmp_path, + tokenizer, + dummy_data_per_token_classification_ft, + with_peft: bool, + n_steps_train: int = 50, + seed: int = 42, +): + if with_peft: + pytest.xfail("FIXME PEFT fine-tuning not supported with fusions active") + + with megatron_parallel_state_utils.distributed_model_parallel_state(seed): + if with_peft: + peft = ESM2LoRA() + else: + peft = None + esm2_finetune_config = ESM2FineTuneTokenConfig(initial_ckpt_path=str(pretrain_ckpt_path)) + dataset = InMemoryPerTokenValueDataset(dummy_data_per_token_classification_ft, seed=seed) + finetune_data_module = ESM2FineTuneDataModule(dataset, dataset) + simple_ft_checkpoint, simple_ft_metrics, trainer = train_model( + experiment_name="finetune_new_head", + experiment_dir=tmp_path / "finetune_new_head", # new checkpoint will land in a subdir of this + config=esm2_finetune_config, # same config as before since we are just continuing training + data_module=finetune_data_module, + n_steps_train=n_steps_train, + metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]), + tokenizer=tokenizer, + peft=peft, + _use_rich_model_summary=False, + ) + + weights_ckpt = simple_ft_checkpoint / "weights" + assert weights_ckpt.exists() + assert weights_ckpt.is_dir() + assert io.is_distributed_ckpt(weights_ckpt) + assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1] + + if with_peft: + assert trainer.model.model_transform is not None + model = trainer.model[0].module.module.module + assert all(not p.requires_grad for p in model.embedding.parameters()) + assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name) + assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name) + assert all(p.requires_grad for p in model.classification_head.parameters()) + else: + encoder_requires_grad = [ + p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name + ] + assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning" + + +@pytest.mark.needs_gpu +@pytest.mark.parametrize("with_peft", [True, False]) +def test_esm2_finetune_regressor( + tmp_path, + tokenizer, + dummy_data_single_value_regression_ft, + with_peft: bool, + n_steps_train: int = 50, + seed: int = 42, +): + if with_peft: + pytest.xfail("FIXME PEFT fine-tuning not supported") + + with megatron_parallel_state_utils.distributed_model_parallel_state(seed): + if with_peft: + peft = ESM2LoRA() + else: + peft = None + esm2_regression_finetune_config = ESM2FineTuneSeqConfig(initial_ckpt_path=str(pretrain_ckpt_path)) + 
dataset = InMemorySingleValueDataset(dummy_data_single_value_regression_ft, seed=seed) + finetune_data_module = ESM2FineTuneDataModule(dataset, dataset) + simple_ft_checkpoint, simple_ft_metrics, trainer = train_model( + experiment_name="finetune_new_head_regression", + experiment_dir=tmp_path / "finetune_new_head_regression", # new checkpoint will land in a subdir of this + config=esm2_regression_finetune_config, # same config as before since we are just continuing training + data_module=finetune_data_module, + n_steps_train=n_steps_train, + metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]), + tokenizer=tokenizer, + peft=peft, + _use_rich_model_summary=False, + ) + + weights_ckpt = simple_ft_checkpoint / "weights" + assert weights_ckpt.exists() + assert weights_ckpt.is_dir() + assert io.is_distributed_ckpt(weights_ckpt) + assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1] + + if with_peft: + assert trainer.model.model_transform is not None + model = trainer.model[0].module.module.module + assert all(not p.requires_grad for p in model.embedding.parameters()) + assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name) + assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name) + assert all(p.requires_grad for p in model.regression_head.parameters()) + else: + encoder_requires_grad = [ + p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name + ] + assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning" diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py deleted file mode 100644 index c51af07d5a..0000000000 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py +++ /dev/null @@ -1,311 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
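
Both the new tests above and the deleted script-level tests below verify encoder freezing the same way: every encoder parameter should end up with `requires_grad` set to `False`, while the task head (and, with LoRA, the adapter weights) stay trainable. A small helper along these lines makes that check reusable; it is a sketch only, the `model` object in the usage comment is hypothetical, and the `"adapter"` and head-name substrings are taken from the assertions in `test_finetune.py`.

```python
# Sketch of a reusable trainable-parameter check for fine-tuned ESM-2 models.
import torch


def trainable_parameter_names(model: torch.nn.Module) -> list[str]:
    """Return the names of parameters that would be updated during fine-tuning."""
    return [name for name, param in model.named_parameters() if param.requires_grad]


# Example usage (hypothetical `model` object): with a frozen encoder, only head and
# LoRA adapter parameters should remain trainable.
# names = trainable_parameter_names(model)
# assert all(("regression_head" in n) or ("adapter" in n) for n in names)
```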
- - -from pathlib import Path -from unittest.mock import patch - -import pandas as pd -import pytest -from nemo.lightning import io - -from bionemo.core.data.load import load -from bionemo.esm2.model.finetune.dataset import InMemoryPerTokenValueDataset, InMemorySingleValueDataset -from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig -from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig -from bionemo.esm2.scripts.finetune_esm2 import finetune_esm2_entrypoint, get_parser, train_model -from bionemo.testing import megatron_parallel_state_utils -from bionemo.testing.callbacks import MetricTracker - - -def data_to_csv(data, tmp_path): - """Create a mock protein dataset.""" - csv_file = tmp_path / "protein_dataset.csv" - # Create a DataFrame - df = pd.DataFrame(data, columns=["sequences", "labels"]) - - # Save the DataFrame to a CSV file - df.to_csv(csv_file, index=False) - return csv_file - - -@pytest.mark.needs_gpu -def test_esm2_finetune_token_classifier( - tmp_path, - dummy_data_per_token_classification_ft, - n_steps_train: int = 50, - seed: int = 42, -): - with megatron_parallel_state_utils.distributed_model_parallel_state(seed): - simple_ft_checkpoint, simple_ft_metrics, trainer = train_model( - train_data_path=data_to_csv(dummy_data_per_token_classification_ft, tmp_path), - valid_data_path=data_to_csv(dummy_data_per_token_classification_ft, tmp_path), - experiment_name="finetune_new_head_token_classification", - restore_from_checkpoint_path=str(load("esm2/8m:2.0")), - num_steps=n_steps_train, - num_nodes=1, - devices=1, - min_seq_length=None, - max_seq_length=1024, - result_dir=tmp_path / "finetune", - limit_val_batches=2, - val_check_interval=n_steps_train // 2, - log_every_n_steps=n_steps_train // 2, - num_dataset_workers=10, - lr=1e-5, - micro_batch_size=4, - accumulate_grad_batches=1, - resume_if_exists=False, - precision="bf16-mixed", - dataset_class=InMemoryPerTokenValueDataset, - config_class=ESM2FineTuneTokenConfig, - metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]), - ) - - weights_ckpt = simple_ft_checkpoint / "weights" - assert weights_ckpt.exists() - assert weights_ckpt.is_dir() - assert io.is_distributed_ckpt(weights_ckpt) - assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1] - - encoder_requires_grad = [ - p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name - ] - assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning" - - -@pytest.mark.needs_gpu -def test_esm2_finetune_regressor( - tmp_path, - dummy_data_single_value_regression_ft, - n_steps_train: int = 50, - seed: int = 42, -): - with megatron_parallel_state_utils.distributed_model_parallel_state(seed): - simple_ft_checkpoint, simple_ft_metrics, trainer = train_model( - train_data_path=data_to_csv(dummy_data_single_value_regression_ft, tmp_path), - valid_data_path=data_to_csv(dummy_data_single_value_regression_ft, tmp_path), - experiment_name="finetune_new_head_regression", - restore_from_checkpoint_path=str(load("esm2/8m:2.0")), - num_steps=n_steps_train, - num_nodes=1, - devices=1, - min_seq_length=None, - max_seq_length=1024, - result_dir=tmp_path / "finetune", - limit_val_batches=2, - val_check_interval=n_steps_train // 2, - log_every_n_steps=n_steps_train // 2, - num_dataset_workers=10, - lr=1e-5, - micro_batch_size=4, - accumulate_grad_batches=1, - resume_if_exists=False, - 
precision="bf16-mixed", - dataset_class=InMemorySingleValueDataset, - config_class=ESM2FineTuneSeqConfig, - metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]), - ) - - weights_ckpt = simple_ft_checkpoint / "weights" - assert weights_ckpt.exists() - assert weights_ckpt.is_dir() - assert io.is_distributed_ckpt(weights_ckpt) - assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1] - - encoder_requires_grad = [ - p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name - ] - assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning" - - -@pytest.fixture -def mock_train_model(): - with patch("bionemo.esm2.scripts.finetune_esm2.train_model") as mock_train: - yield mock_train - - -@pytest.fixture -def mock_parser_args(): - """Fixture to create mock arguments for the parser.""" - return [ - "--train-data-path", - str(Path("train.csv")), - "--valid-data-path", - str(Path("valid.csv")), - "--num-gpus", - "1", - "--num-nodes", - "1", - "--min-seq-length", - "512", - "--max-seq-length", - "1024", - "--result-dir", - str(Path("./results")), - "--lr", - "0.001", - ] - - -def test_finetune_esm2_entrypoint(mock_train_model, mock_parser_args): - """Test the finetune_esm2_entrypoint function with mocked arguments.""" - with patch("sys.argv", ["finetune_esm2_entrypoint.py"] + mock_parser_args): - finetune_esm2_entrypoint() - - # Check if train_model was called once - mock_train_model.assert_called_once() - - # Check if the arguments were passed correctly - called_kwargs = mock_train_model.call_args.kwargs - assert called_kwargs["train_data_path"] == Path("train.csv") - assert called_kwargs["valid_data_path"] == Path("valid.csv") - assert called_kwargs["devices"] == 1 - assert called_kwargs["num_nodes"] == 1 - assert called_kwargs["min_seq_length"] == 512 - assert called_kwargs["max_seq_length"] == 1024 - assert called_kwargs["lr"] == 0.001 - assert called_kwargs["result_dir"] == Path("./results") - - -def test_get_parser(): - """Test the argument parser with all possible arguments.""" - parser = get_parser() - args = parser.parse_args( - [ - "--train-data-path", - "train.csv", - "--valid-data-path", - "valid.csv", - "--precision", - "bf16-mixed", - "--lr", - "0.001", - "--create-tensorboard-logger", - "--resume-if-exists", - "--result-dir", - "./results", - "--experiment-name", - "esm2_experiment", - "--wandb-entity", - "my_team", - "--wandb-project", - "ft_project", - "--wandb-tags", - "tag1", - "tag2", - "--wandb-group", - "group1", - "--wandb-id", - "1234", - "--wandb-anonymous", - "--wandb-log-model", - "--wandb-offline", - "--num-gpus", - "2", - "--num-nodes", - "1", - "--num-steps", - "1000", - "--num-dataset-workers", - "4", - "--val-check-interval", - "500", - "--log-every-n-steps", - "100", - "--min-seq-length", - "512", - "--max-seq-length", - "1024", - "--limit-val-batches", - "2", - "--micro-batch-size", - "32", - "--pipeline-model-parallel-size", - "2", - "--tensor-model-parallel-size", - "2", - "--accumulate-grad-batches", - "2", - "--save-last-checkpoint", - "--metric-to-monitor-for-checkpoints", - "val_loss", - "--save-top-k", - "5", - "--restore-from-checkpoint-path", - "./checkpoint", - "--nsys-profiling", - "--nsys-start-step", - "10", - "--nsys-end-step", - "50", - "--nsys-ranks", - "0", - "1", - "--no-overlap-grad-reduce", - "--no-overlap-param-gather", - "--no-average-in-collective", - "--grad-reduce-in-fp32", - 
"--dataset-class", - "InMemoryPerTokenValueDataset", - "--config-class", - "ESM2FineTuneTokenConfig", - ] - ) - - # Assertions for all arguments - assert args.train_data_path == Path("train.csv") - assert args.valid_data_path == Path("valid.csv") - assert args.precision == "bf16-mixed" - assert args.lr == 0.001 - assert args.create_tensorboard_logger is True - assert args.resume_if_exists is True - assert args.result_dir == Path("./results") - assert args.experiment_name == "esm2_experiment" - assert args.wandb_entity == "my_team" - assert args.wandb_project == "ft_project" - assert args.wandb_tags == ["tag1", "tag2"] - assert args.wandb_group == "group1" - assert args.wandb_id == "1234" - assert args.wandb_anonymous is True - assert args.wandb_log_model is True - assert args.wandb_offline is True - assert args.num_gpus == 2 - assert args.num_nodes == 1 - assert args.num_steps == 1000 - assert args.num_dataset_workers == 4 - assert args.val_check_interval == 500 - assert args.log_every_n_steps == 100 - assert args.min_seq_length == 512 - assert args.max_seq_length == 1024 - assert args.limit_val_batches == 2 - assert args.micro_batch_size == 32 - assert args.pipeline_model_parallel_size == 2 - assert args.tensor_model_parallel_size == 2 - assert args.accumulate_grad_batches == 2 - assert args.save_last_checkpoint is True - assert args.metric_to_monitor_for_checkpoints == "val_loss" - assert args.save_top_k == 5 - assert args.restore_from_checkpoint_path == Path("./checkpoint") - assert args.nsys_profiling is True - assert args.nsys_start_step == 10 - assert args.nsys_end_step == 50 - assert args.nsys_ranks == [0, 1] - assert args.no_overlap_grad_reduce is True - assert args.no_overlap_param_gather is True - assert args.no_average_in_collective is True - assert args.grad_reduce_in_fp32 is True - assert args.dataset_class == InMemoryPerTokenValueDataset - assert args.config_class == ESM2FineTuneTokenConfig diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py index b3c349b8f3..bf080da23a 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py @@ -19,17 +19,45 @@ import pandas as pd import pytest import torch +from torch.utils.data import DataLoader from bionemo.core.data.load import load from bionemo.core.utils.dtypes import get_autocast_dtype from bionemo.esm2.api import ESM2Config from bionemo.esm2.data.tokenizer import get_tokenizer +from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule, InMemoryCSVDataset from bionemo.esm2.scripts.infer_esm2 import infer_model from bionemo.llm.data import collate from bionemo.llm.lightning import batch_collator from bionemo.llm.utils.callbacks import IntervalT +# Function to check GPU memory +def check_gpu_memory(threshold_gb): + if torch.cuda.is_available(): + gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Memory in GB + return gpu_memory < threshold_gb + return False + + +@pytest.fixture +def dummy_protein_sequences(): + """Create a list of artificial protein sequences""" + artificial_sequence_data = [ + "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI", + "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF", + "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN", + "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF", + "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI", + 
"LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP", + "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF", + "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF", + "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP", + "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT", + ] + return artificial_sequence_data + + @pytest.fixture def dummy_protein_csv(tmp_path, dummy_protein_sequences): """Create a mock protein dataset.""" @@ -42,6 +70,16 @@ def dummy_protein_csv(tmp_path, dummy_protein_sequences): return csv_file +@pytest.fixture +def dataset(dummy_protein_csv): + return InMemoryCSVDataset(dummy_protein_csv) + + +@pytest.fixture +def data_module(dataset): + return ESM2FineTuneDataModule(predict_dataset=dataset) + + @pytest.fixture def padded_tokenized_sequences(dummy_protein_sequences): tokenizer = get_tokenizer() @@ -53,6 +91,49 @@ def padded_tokenized_sequences(dummy_protein_sequences): return collated_batch["text"] +def test_in_memory_csv_dataset(dataset): + assert len(dataset) > 0 + sample = dataset[0] + assert isinstance(sample, dict) + assert "text" in sample + assert "labels" in sample + + +def test_in_memory_csv_dataset_load_data(dataset, dummy_protein_csv): + sequences, labels = dataset.load_data(dummy_protein_csv) + assert isinstance(sequences, list) + assert isinstance(labels, list) + + +def test_esm2_fine_tune_data_module_init(data_module): + assert data_module.train_dataset is None + assert data_module.valid_dataset is None + assert data_module.predict_dataset is not None + + +def test_esm2_fine_tune_data_module_predict_dataloader(data_module): + predict_dataloader = data_module.predict_dataloader() + assert isinstance(predict_dataloader, DataLoader) + batch = next(iter(predict_dataloader)) + assert isinstance(batch, dict) + assert "text" in batch + + +def test_esm2_fine_tune_data_module_setup(data_module): + with pytest.raises(RuntimeError): + data_module.setup("fit") + + +def test_esm2_fine_tune_data_module_train_dataloader(data_module): + with pytest.raises(AttributeError): + data_module.train_dataloader() + + +def test_esm2_fine_tune_data_module_val_dataloader(data_module): + with pytest.raises(AttributeError): + data_module.val_dataloader() + + @pytest.mark.parametrize("precision", ["fp32", "bf16-mixed"]) @pytest.mark.parametrize("prediction_interval", get_args(IntervalT)) def test_infer_runs(