diff --git a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
deleted file mode 100644
index aa26cd72ff..0000000000
--- a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
+++ /dev/null
@@ -1,898 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "[![ Click here to deploy.](https://uohmivykqgnnbiouffke.supabase.co/storage/v1/object/public/landingpage/brevdeploynavy.svg)](https://console.brev.dev/launchable/deploy?launchableID=env-2rPWpPzzJIxq7SMRJIQehCxBymV)\n",
- "\n",
- "
NOTE It takes about 10 minutes to deploy this notebook as a Launchable. As of this writing, we are working on a free tier so a credit card may be required. You can reach out to your NVIDIA rep for credits.
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# ESM-2 Fine-tuning"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "vscode": {
- "languageId": "plaintext"
- }
- },
- "source": [
- "The [ESM-2](https://www.science.org/doi/abs/10.1126/science.ade2574) model is a transformer-based protein language model that has achieved state-of-the-art results in various protein-related tasks. When fine-tuning ESM2, the task-head plays a crucial role. A task head refers to the additional layer or set of layers added on top of a pre-trained model, like the ESM-2 transformer-based protein language model, to adapt it for a specific downstream task. As a part of transfer learning, a pre-trained model is often utilized to learn generic features from a large-scale dataset. However, these features might not be directly applicable to the specific task at hand. By incorporating a task head, which consists of learnable parameters, the model can adapt and specialize to the target task. The task head serves as a flexible and adaptable component that learns task-specific representations by leveraging the pre-trained features as a foundation. Through fine-tuning, the task head enables the model to learn and extract task-specific patterns, improving performance and addressing the nuances of the downstream task. It acts as a critical bridge between the pre-trained model and the specific task, enabling efficient and effective transfer of knowledge."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- " NOTE This tutorial will guide you through the steps for creating a basic regression fine-tuning task for simplicity. The utilities described in this tutorial are available in:\n",
- "\n",
- "
bionemo.esm2.model.finetune.finetune_regressor
\n",
- "\n",
- "The techniques demonstrated here can be adapted for classification and per-token classification tasks. Utilities needed for secondary structure prediction (token-level classification) are available in \n",
- "\n",
- "
bionemo.esm2.model.finetune.finetune_token_classifier
\n",
- "\n",
- "In the second part of the tutorial, we will cover loading a pre-trained model, fine-tuning it for both regression and per-token classification tasks, and using the fine-tuned models for inference. For instructions on pre-training the ESM-2 model, please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial.
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building a Regression Fine-tune Module"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Wwe need to define some key classes to successfully build a fine-tuning module in BioNeMo framework: \n",
- "\n",
- "1. **Loss Reduction Class** - To compute the supervised fine-tuning loss.\n",
- "2. **Fine-Tuned Model Head** - Downstream task head model.\n",
- "3. **Fine-Tuned Model** - Model that combines ESM-2 with the task head model.\n",
- "4. **Fine-Tuning Config** - Configures the fine-tuning model and loss to use in the training and inference framework.\n",
- "5. **Dataset** - Training and inference datasets for ESM2 fine-tuning."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 1 - Loss Reduction Class"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A class for calculating the supervised loss of the fine-tune model from targets. We inherit from Megatron Bert Masked Language Model Loss (`BERTMLMLossWithReduction`) and override the `forward()` pass to compute MSE loss of the regression head within a micro-batch. The `reduce()` method is used for computing the average over the micro-batches and is only used for logging.\n",
- "\n",
- "```python\n",
- "class RegressorLossReduction(BERTMLMLossWithReduction):\n",
- " def forward(\n",
- " self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]\n",
- " ) -> Tuple[torch.Tensor, Union[PerTokenLossDict, SameSizeLossDict]]:\n",
- "\n",
- " regression_output = forward_out[\"regression_output\"]\n",
- " targets = batch[\"labels\"].to(dtype=regression_output.dtype) # [b, 1]\n",
- "\n",
- " loss = torch.nn.functional.mse_loss(regression_output, targets)\n",
- " return loss, {\"avg\": loss}\n",
- "\n",
- " def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> torch.Tensor:\n",
- " losses = torch.stack([loss[\"avg\"] for loss in losses_reduced_per_micro_batch])\n",
- " return losses.mean()\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 2 - Fine-Tuned Model Head"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "An MLP class for sequence-level regression. This class inherits `MegatronModule` and uses the fine-tune config (`TransformerConfig`) to configure the regression head for the fine-tuned ESM-2 model.\n",
- "\n",
- "```python\n",
- "class MegatronMLPHead(MegatronModule):\n",
- " def __init__(self, config: TransformerConfig):\n",
- " super().__init__(config)\n",
- " layer_sizes = [config.hidden_size, 256, 1]\n",
- " self.linear_layers = torch.nn.ModuleList(\n",
- " [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]\n",
- " )\n",
- " self.act = torch.nn.ReLU()\n",
- " self.dropout = torch.nn.Dropout(p=config.ft_dropout)\n",
- "\n",
- " def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:\n",
- " ...\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 3 - Fine-Tuned Model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A fine-tuned ESM-2 model class for token classification tasks. This class inherits from the `ESM2Model` class and adds the custom regression head `MegatronMLPHead` the we created in the previous step. Optionally one can freeze all or parts of the encoder by parsing through the model parameters in the model constructor.\n",
- "\n",
- "```python\n",
- "class ESM2FineTuneSeqModel(ESM2Model):\n",
- " def __init__(self, config, *args, post_process: bool = True, include_embeddings: bool = False, **kwargs):\n",
- " super().__init__(config, *args, post_process=post_process, include_embeddings=True, **kwargs)\n",
- "\n",
- " # freeze encoder parameters\n",
- " if config.encoder_frozen:\n",
- " for _, param in self.named_parameters():\n",
- " param.requires_grad = False\n",
- "\n",
- " if post_process:\n",
- " self.regression_head = MegatronMLPHead(config)\n",
- "\n",
- " def forward(self, *args, **kwargs,):\n",
- " output = super().forward(*args, **kwargs)\n",
- " ...\n",
- " output[\"regression_output\"] = self.regression_head(embeddings)\n",
- " return output\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4 - Fine-Tuning Config"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A `dataclass` that configures the fine-tuned ESM-2 model. In this example `ESM2FineTuneSeqConfig` inherits from `ESM2GenericConfig` and adds custom arguments to setup the fine-tuned model. The `configure_model()` method of this `dataclass` is called within the `Lightning` module to call the model constructor with the `dataclass` arguments.\n",
- "\n",
- "The common arguments among different fine-tuning tasks are\n",
- "\n",
- "- `model_cls`: The fine-tune model class defined in previous step (`ESM2FineTuneSeqModel`)\n",
- "- `initial_ckpt_path`: BioNeMo 2.0 ESM-2 pre-trained checkpoint\n",
- "- `initial_ckpt_skip_keys_with_these_prefixes`: skips keys when loading parameters from a checkpoint. For example, we should not look for `regression_head` in the pre-trained checkpoint.\n",
- "- `get_loss_reduction_class()`: Implements selection of the appropriate `MegatronLossReduction` class that we defined in the first step of this tutorial.\n",
- "\n",
- "```python\n",
- "\n",
- "@dataclass\n",
- "class ESM2FineTuneSeqConfig(\n",
- " ESM2GenericConfig[ESM2FineTuneSeqModel, RegressorLossReduction], iom.IOMixinWithGettersSetters\n",
- "):\n",
- " model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel\n",
- " # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in\n",
- " # self.override_parent_fields will be loaded from the checkpoint and override those values here.\n",
- " initial_ckpt_path: str | None = None\n",
- " # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint\n",
- " # that has this new head and want to keep using these weights, please drop this next line or set to []\n",
- " initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: [\"regression_head\"])\n",
- "\n",
- " encoder_frozen: bool = True # freeze encoder parameters\n",
- " ft_dropout: float = 0.25 # MLP layer dropout\n",
- "\n",
- " def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:\n",
- " return RegressorLossReduction\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 5 - Dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "We will use a sample dataset for demonstration purposes. Create a dataset class by extending ```bionemo.esm2.model.finetune.dataset.InMemoryProteinDataset```. The `InMemoryProteinDataset` has a `classmethod` (`from_csv`) that reads data from a CSV file that has `sequences` and optionally `labels` columns. It is important to override the `transform_label()` method that returns a `torch.Tensor` containing the label in correct format. As an example we can use this method to add custom tokenization if `label` is a string."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "The custom dataset class will be appropriate (found in ```bionemo.esm2.model.finetune.dataset.InMemorySingleValueDataset```) as it facilitates predicting on a single value. An excerpt from the class is shown below. This example dataset has a class method `from_csv()` that expects a `data_path` to a CSV file that has `sequences`, and `labels` columns.\n",
- "\n",
- "```python\n",
- "class InMemorySingleValueDataset(InMemoryProteinDataset):\n",
- " def __init__(\n",
- " self,\n",
- " sequences: pd.Series,\n",
- " labels: pd.Series | None = None,\n",
- " tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),\n",
- " seed: int = np.random.SeedSequence().entropy,\n",
- " ):\n",
- " super().__init__(sequences, labels, tokenizer, seed)\n",
- "\n",
- " def transform_label(self, label: float) -> Tensor:\n",
- " return torch.tensor([label], dtype=torch.float)\n",
- "```\n",
- "\n",
- "The `transform_label` method allows for custom transformation of raw labels by casting or tokenization and need to be adjusted based on the data. Here we use this method to create a `float` tensor of the regression value."
- ]
- },
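- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For instance, if the labels were strings (as in the per-token secondary structure example later in this tutorial), a hypothetical override could map each character to a class id before returning a tensor. The sketch below is for illustration only; the actual token-level implementation is provided in `bionemo.esm2.model.finetune.dataset.InMemoryPerTokenValueDataset`.\n",
- "\n",
- "```python\n",
- "class SecondaryStructureDataset(InMemoryProteinDataset):\n",
- "    # hypothetical three-class vocabulary for secondary structure labels\n",
- "    label2id = {'C': 0, 'H': 1, 'E': 2}\n",
- "\n",
- "    def transform_label(self, label: str) -> Tensor:\n",
- "        # convert a per-residue string label such as 'CCHHE' into a tensor of class ids\n",
- "        return torch.tensor([self.label2id[c] for c in label], dtype=torch.int64)\n",
- "```"
- ]
- },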
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### DataModule"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To coordinate the creation of training, validation and testing datasets from your data, we need to use a `datamodule` class. To do this we can directly use or extend the ```ESM2FineTuneDataModule``` class (located at ```bionemo.esm2.model.finetune.datamodule.ESM2FineTuneDataModule```) which defines helpful abstract methods that use your dataset class.\n",
- "\n",
- "```python\n",
- "dataset = InMemorySingleValueDataset.from_csv(data_path)\n",
- "data_module = ESM2FineTuneDataModule(\n",
- " train_dataset=dataset,\n",
- " valid_dataset=dataset\n",
- " micro_batch_size=4, # size of a batch to be processed in a device\n",
- " global_batch_size=8, # size of batch across all devices. Should be multiple of micro_batch_size\n",
- ")\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In the next part of this tutorial we will prepare the input needed to run regression and per-token-classification fine-tuning examples."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Setup and Assumptions"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "All commands should be executed inside the BioNeMo docker container, which has all ESM-2 dependencies pre-installed. For more information on how to build or pull the BioNeMo2 container, refer to the [Initialization Guide](../../getting-started/initialization-guide.md)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- " NOTE Some of the cells below generate long text output. We're using
%%capture --no-display --no-stderr cell_output
to suppress this output. Comment or delete this line in the cells below to restore full output.
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Import Required Libraries"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "import os\n",
- "import shutil\n",
- "import pandas as pd\n",
- "\n",
- "import warnings\n",
- "warnings.filterwarnings('ignore')\n",
- "warnings.simplefilter('ignore')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Work Directory"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Set the work directory to store data and results:\n",
- "\n",
- " NOTE We set the following to clean up the work directory created by this notebook
cleanup : bool = True
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "cleanup : bool = True"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Directory '/workspace/bionemo2/esm2_finetune_tutorial' created.\n"
- ]
- }
- ],
- "source": [
- "work_dir=\"/workspace/bionemo2/esm2_finetune_tutorial\"\n",
- "\n",
- "if cleanup and os.path.exists(work_dir):\n",
- " shutil.rmtree(work_dir)\n",
- "\n",
- "if not os.path.exists(work_dir):\n",
- " os.makedirs(work_dir)\n",
- " print(f\"Directory '{work_dir}' created.\")\n",
- "else:\n",
- " print(f\"Directory '{work_dir}' already exists.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Download Pre-trained Model Checkpoints"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The following code will download the internally pre-trained model, `esm2/nv_8m:2.0`, from the NGC registry. Please refer to [ESM-2 Model Overview](../../../models/ESM-2/index.md) for a list of available checkpoints."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "/home/ubuntu/.cache/bionemo/b4ea4d52eea8a25d2c2838617ff678f0da22d384cee195b0c192686816078dcd-esm2_8m_checkpoint.tar.gz.untar\n"
- ]
- }
- ],
- "source": [
- "from bionemo.core.data.load import load\n",
- "\n",
- "checkpoint_path = load(\"esm2/nv_8m:2.0\")\n",
- "print(checkpoint_path)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The above example is downloading an internally trained 8M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load` as shown above.\n",
- "\n",
- "```bash\n",
- "download_bionemo_data esm2/650m:2.0\n",
- "```\n",
- "\n",
- "which returns the checkpoint path (e.g. `.../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar`)\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Fine-tuning"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can take advantage of the ESM2 fine-tuning script in ```bionemo.esm2.scripts.finetune_esm2``` or use the ```finetune_esm2``` executable the fine-tuning process given:\n",
- "\n",
- "- Pre-trained checkpoint of ESM2\n",
- "- Finetune config class name that configures the finetune model and loss reduction\n",
- "- Path to train and validation CSV data files\n",
- "- Dataset class name\n",
- "\n",
- "To get the full list of arguments to tune a finetuning run use:\n",
- "```bash\n",
- "finetune_esm2 --help \n",
- "```\n",
- "For a detailed description of training loop and the arguments please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- " NOTE\n",
- "\n",
- "Due to Megatron limitations, the log produced by the training run iterates on steps/iterations and not epochs. Therefore, `Training epoch` counter stays at value zero while `iteration` and `global_step` increase during the course of training (example in the following).\n",
- "\n",
- "
\n",
- "Training epoch 0, iteration | ... | global_step: | reduced_train_loss: ... | val_loss: ...\n",
- "
\n",
- "\n",
- "to achieve the same epoch-based effect while training, please choose the number of training steps (`num_steps`) so that:\n",
- "\n",
- "
\n",
- "num_steps * global_batch_size = len(dataset) * desired_num_epochs\n",
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Regression"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For the purposes of this demo, we'll assume dataset consists of small set of protein sequences with a target value of `len(sequence) / 100.0` as their labels:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "\n",
- "artificial_sequence_data = [\n",
- " \"TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI\",\n",
- " \"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
- " \"GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN\",\n",
- " \"DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
- " \"KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI\",\n",
- " \"LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP\",\n",
- " \"LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
- " \"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
- " \"ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP\",\n",
- " \"SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT\",\n",
- "]\n",
- "\n",
- "data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]\n",
- "\n",
- "# Create a DataFrame\n",
- "df = pd.DataFrame(data, columns=[\"sequences\", \"labels\"])\n",
- "\n",
- "# Save the DataFrame to a CSV file\n",
- "data_path = os.path.join(work_dir, \"regression_data.csv\")\n",
- "df.to_csv(data_path, index=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "! finetune_esm2 \\\n",
- " --restore-from-checkpoint-path {checkpoint_path} \\\n",
- " --train-data-path {data_path} \\\n",
- " --valid-data-path {data_path} \\\n",
- " --config-class ESM2FineTuneSeqConfig \\\n",
- " --dataset-class InMemorySingleValueDataset \\\n",
- " --experiment-name \"regression\" \\\n",
- " --num-steps 50 \\\n",
- " --num-gpus 1 \\\n",
- " --val-check-interval 10 \\\n",
- " --log-every-n-steps 10 \\\n",
- " --lr 5e-3 \\\n",
- " --result-dir {work_dir} \\\n",
- " --micro-batch-size 2 \\\n",
- " --num-gpus 1 \\\n",
- " --precision \"bf16-mixed\"\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The previous cell executes the finetuning and saves the checkpoints at the end of the run. The checkpoint path is logged at the end of the finetuning log file: \n",
- "\n",
- "```\n",
- "[NeMo I $$$$-$$-$$ 22:04:28 nemo_logging:393] Async checkpoint save for step 50 (/workspace/bionemo2/esm2_finetune_tutorial/regression/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last-v1.ckpt) finalized successfully.\n",
- "```\n",
- "\n",
- "To avoid long text output from the previous cell, the log is captured and stored into the `cell_output` variable. To visualize the log file uncomment and execute the next cell:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "# print(cell_output.stdout)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can now use the checkpoint stored in the previous step to run inference. We will drop the `.ckpt` from the checkpoint path and provide that to the `--checkpoint-path` argument of `infer_esm2` executable.\n",
- "\n",
- "The input `--data-path` for inference is a CSV file with `sequences` column. It is also required to provide the appropriate `--config-class` name to load the model from the checkpoint. For a detailed description of inference arguments please refer to the [ESM-2 Inference](./inference.ipynb) tutorial."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Create a DataFrame\n",
- "df = pd.DataFrame(artificial_sequence_data, columns=[\"sequences\"])\n",
- "\n",
- "# Save the DataFrame to a CSV file\n",
- "data_path = os.path.join(work_dir, \"sequences.csv\")\n",
- "df.to_csv(data_path, index=False)\n",
- "\n",
- "checkpoint_path = f\"{work_dir}/regression/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last\"\n",
- "results_path = f\"{work_dir}/regression/infer/\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
- " --config-class ESM2FineTuneSeqConfig \\\n",
- " --data-path {data_path} \\\n",
- " --results-path {results_path} \\\n",
- " --micro-batch-size 3 \\\n",
- " --num-gpus 1 \\\n",
- " --precision \"bf16-mixed\" \\\n",
- " --include-embeddings \\\n",
- " --include-input-ids"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The inference results are written into a `.pt` file which can be loaded using PyTorch library:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input_ids\ttorch.Size([10, 1024])\n",
- "embeddings\ttorch.Size([10, 320])\n",
- "regression_output\ttorch.Size([10, 1])\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "results = torch.load(f\"{results_path}/predictions__rank_0.pt\")\n",
- "\n",
- "for key, val in results.items():\n",
- " if val is not None:\n",
- " print(f'{key}\\t{val.shape}')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Toke-level Classification data"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For this task we assign secondary structure label to each token in the sequence:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "secondary_structure_labels = [\n",
- " \"EEEECCCCCHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE\",\n",
- " \"CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
- " \"HHHHHCCCCCHHHHHHHHHHHHHHCCCHHHHHHHHHH\",\n",
- " \"HHHHHHHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
- " \"CHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE\",\n",
- " \"HHHHHHHHHHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC\",\n",
- " \"HHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
- " \"CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
- " \"HHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC\",\n",
- " \"CCCCCCCCCCCCCCCCCCCCCCCCCCEEECCCCEEECHHHHHHHHHCCCCCCCCEEECCCCCC\",\n",
- "]\n",
- "\n",
- "data = [(seq, label) for (seq, label) in zip(artificial_sequence_data, secondary_structure_labels)]\n",
- "\n",
- "# Create a DataFrame\n",
- "df = pd.DataFrame(data, columns=[\"sequences\", \"labels\"])\n",
- "\n",
- "# Save the DataFrame to a CSV file\n",
- "data_path = os.path.join(work_dir, \"token_classification_data.csv\")\n",
- "df.to_csv(data_path, index=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "! finetune_esm2 \\\n",
- " --restore-from-checkpoint-path {checkpoint_path} \\\n",
- " --train-data-path {data_path} \\\n",
- " --valid-data-path {data_path} \\\n",
- " --config-class ESM2FineTuneTokenConfig \\\n",
- " --dataset-class InMemoryPerTokenValueDataset \\\n",
- " --experiment-name \"token_level_classification\" \\\n",
- " --num-steps 50 \\\n",
- " --num-gpus 1 \\\n",
- " --val-check-interval 10 \\\n",
- " --log-every-n-steps 10 \\\n",
- " --lr 5e-3 \\\n",
- " --result-dir {work_dir} \\\n",
- " --micro-batch-size 2 \\\n",
- " --num-gpus 1 \\\n",
- " --precision \"bf16-mixed\"\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The previous cell executes the finetuning and saves the checkpoints at the end of the run. The checkpoint path is logged at the end of the finetuning log file: \n",
- "\n",
- "```\n",
- "[NeMo I $$$$-$$-$$ 22:16:46 nemo_logging:393] Async checkpoint save for step 50 (/workspace/bionemo2/esm2_finetune_tutorial/token_level_classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last.ckpt) finalized successfully.\n",
- "```\n",
- "\n",
- "To avoid long text output from the previous cell, the log is captured and stored into the `cell_output` variable. To visualize the log file uncomment and execute the next cell:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "# print(cell_output.stdout)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can now use the checkpoint stored in the previous step to run inference. We will drop the `.ckpt` from the checkpoint path and provide that to the `--checkpoint-path` argument of `infer_esm2` executable.\n",
- "\n",
- "The input `--data-path` for inference is a CSV file with `sequences` column. It is also required to provide the appropriate `--config-class` name to load the model from the checkpoint. For a detailed description of inference arguments please refer to the [ESM-2 Inference](./inference.ipynb) tutorial."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Create a DataFrame\n",
- "df = pd.DataFrame(artificial_sequence_data, columns=[\"sequences\"])\n",
- "\n",
- "# Save the DataFrame to a CSV file\n",
- "data_path = os.path.join(work_dir, \"sequences.csv\")\n",
- "df.to_csv(data_path, index=False)\n",
- "\n",
- "checkpoint_path = f\"{work_dir}/token_level_classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last\"\n",
- "results_path = f\"{work_dir}/token_level_classification/infer/\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
- " --config-class ESM2FineTuneTokenConfig \\\n",
- " --data-path {data_path} \\\n",
- " --results-path {results_path} \\\n",
- " --micro-batch-size 3 \\\n",
- " --num-gpus 1 \\\n",
- " --precision \"bf16-mixed\" \\\n",
- " --include-embeddings \\\n",
- " --include-hiddens \\\n",
- " --include-input-ids"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The inference results are written into a `.pt` file which can be loaded using PyTorch library:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "hidden_states\ttorch.Size([10, 1024, 320])\n",
- "input_ids\ttorch.Size([10, 1024])\n",
- "embeddings\ttorch.Size([10, 320])\n",
- "classification_output\ttorch.Size([10, 1024, 3])\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "results = torch.load(f\"{results_path}/predictions__rank_0.pt\")\n",
- "\n",
- "for key, val in results.items():\n",
- " if val is not None:\n",
- " print(f'{key}\\t{val.shape}')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can use the label tokenizer to convert the classification output to class names. Note that for demonstration purposes we are using a small dataset of artificial sequences in this example. You may experience over-fitting and observe no change in the validation metrics. This amount of data and the short training run does not result in accurate predictions."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [],
- "source": [
- "from bionemo.esm2.data.tokenizer import get_tokenizer\n",
- "\n",
- "\n",
- "tokenizer = get_tokenizer()\n",
- "tokens = tokenizer.all_tokens\n",
- "aa_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']\n",
- "aa_indices = [i for i, token in enumerate(tokens) if token in aa_tokens]\n",
- "extra_indices = [i for i, token in enumerate(tokens) if token not in aa_tokens]\n",
- "\n",
- "input_ids = results['input_ids'] # b, s\n",
- "# mask where non-amino acid tokens are True\n",
- "mask = ~torch.isin(input_ids, torch.tensor(extra_indices))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [],
- "source": [
- "from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer\n",
- "\n",
- "label_tokenizer = Label2IDTokenizer()\n",
- "label_tokenizer = label_tokenizer.build_vocab(secondary_structure_labels)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Predicted Secondary Structures:\n",
- "HHHHEEEEECCCCCCCCCCCCCCCEEEHHHHHHEEECCCCCCCCCEEEEEEEEEHHH\n",
- "EEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
- "CCCCCEEEEECCCCCCCCCCCCCEEEECCCCCCCCCC\n",
- "CCCCCCCCCCEEECCCCCEEEEEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
- "ECCCCCCCCCCCCCCCEEEHHHHHHEEECCCCCCCCCEEEEEEEEEHHH\n",
- "CCCCCCCCCCCCCECCCCCCCCCCCCEEEHHEEEHHHHEEHHHHHEE\n",
- "CCCCCEEECCCCCEEEEEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
- "EEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
- "CCCCCECCCCCCCCCCCCEEEHHEEEHHHHEEHHHHHEE\n",
- "EEEEEEEEEEEEEEEEEEEEEEEEEEHHHEEEEHHHECCCCCCCCCEEEEEEEEHHHEEEEEE\n"
- ]
- }
- ],
- "source": [
- "output_ids = torch.argmax(results[\"classification_output\"], dim=-1)\n",
- "\n",
- "print(\"Predicted Secondary Structures:\")\n",
- "for i in range(output_ids.shape[0]):\n",
- " ss_ids = output_ids[i][mask[i]]\n",
- " print(label_tokenizer.ids_to_text(ss_ids.tolist()))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/docs/docs/user-guide/examples/bionemo-esm2/finetune.md b/docs/docs/user-guide/examples/bionemo-esm2/finetune.md
new file mode 100644
index 0000000000..7968a22397
--- /dev/null
+++ b/docs/docs/user-guide/examples/bionemo-esm2/finetune.md
@@ -0,0 +1,263 @@
+# ESM-2 Fine-Tuning
+
+This readme serves as a demo for implementing an ESM-2 fine-tuning module, running a regression example, and using the fine-tuned model for inference.
+
+The ESM-2 model is a transformer-based protein language model that has achieved state-of-the-art results in various protein-related tasks. When fine-tuning ESM2, the task head plays a crucial role. A task head refers to the additional layer or set of layers added on top of a pre-trained model, like the ESM-2 transformer-based protein language model, to adapt it for a specific downstream task. As a part of transfer learning, a pre-trained model is often utilized to learn generic features from a large-scale dataset. However, these features might not be directly applicable to the specific task at hand. By incorporating a task head, which consists of learnable parameters, the model can adapt and specialize to the target task. The task head serves as a flexible and adaptable component that learns task-specific representations by leveraging the pre-trained features as a foundation. Through fine-tuning, the task head enables the model to learn and extract task-specific patterns, improving performance and addressing the nuances of the downstream task. It acts as a critical bridge between the pre-trained model and the specific task, enabling efficient and effective transfer of knowledge.
+
+
+# Setup and Assumptions
+
+In this tutorial, we will demonstrate how to create a fine-tuning module, train a regression task head, and use the fine-tuned model for inference.
+
+All commands should be executed inside the BioNeMo docker container, which has all ESM-2 dependencies pre-installed. This tutorial assumes that a copy of the BioNeMo framework repo exists on your workstation or server and has been mounted inside the container at `/workspace/bionemo2`. (**Note**: This `WORKDIR` may be `/workspaces/bionemo-framework` if you are using the VSCode Dev Container.) For more information on how to build or pull the BioNeMo2 container, refer to the [Access and Startup](../../getting-started/access-startup.md) guide.
+
+To successfully accomplish this we need to define some key classes:
+
+1. Loss Reduction Method - To compute the supervised fine-tuning loss.
+2. Fine-Tuned Model Head - Downstream task head model.
+3. Fine-Tuned Model - Model that combines ESM-2 with the task head model.
+4. Fine-Tuning Config - Configures the fine-tuning model and loss to use in the training and inference framework.
+5. Dataset - Training and inference datasets for ESM2.
+
+## 1 - Loss Reduction Class
+
+A class for calculating the supervised loss of the fine-tune model from targets. We inherit from Megatron Bert Masked Language Model Loss (`BERTMLMLossWithReduction`) and override the `forward()` pass to compute MSE loss of the regression head within a micro-batch. The `reduce()` method is used for computing the average over the micro-batches and is only used for logging.
+
+```python
+class RegressorLossReduction(BERTMLMLossWithReduction):
+    def forward(
+        self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Union[PerTokenLossDict, SameSizeLossDict]]:
+
+        targets = batch["labels"]  # [b, 1]
+        regression_output = forward_out
+        loss = torch.nn.functional.mse_loss(regression_output, targets)
+        return loss, {"avg": loss}
+
+    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> torch.Tensor:
+        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+        return losses.mean()
+```
+
+## 2 - Fine-Tuned Model Head
+
+An MLP class for sequence-level regression. This class inherits `MegatronModule` and uses the fine-tune config (`TransformerConfig`) to configure the regression head for the fine-tuned ESM-2 model.
+
+```python
+class MegatronMLPHead(MegatronModule):
+    def __init__(self, config: TransformerConfig):
+        super().__init__(config)
+        layer_sizes = [config.hidden_size, 256, 1]
+        self.linear_layers = torch.nn.ModuleList(
+            [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
+        )
+        self.act = torch.nn.ReLU()
+        self.dropout = torch.nn.Dropout(p=config.ft_dropout)
+
+    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
+        ...
+```
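+
+The forward pass is elided above. A minimal sketch of how such a head could chain dropout, the linear layers, and the activation is shown below; this is an illustrative assumption written as a plain `torch.nn.Module`, not the exact BioNeMo implementation:
+
+```python
+import torch
+
+
+class SimpleMLPHead(torch.nn.Module):
+    """Plain PyTorch analogue of the MLP regression head, for illustration only."""
+
+    def __init__(self, hidden_size: int = 320, dropout: float = 0.25):
+        super().__init__()
+        layer_sizes = [hidden_size, 256, 1]
+        self.linear_layers = torch.nn.ModuleList(
+            [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
+        )
+        self.act = torch.nn.ReLU()
+        self.dropout = torch.nn.Dropout(p=dropout)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # dropout -> linear for each layer, with ReLU between layers but not after the last one
+        output = hidden_states
+        for i, layer in enumerate(self.linear_layers):
+            output = self.dropout(output)
+            output = layer(output)
+            if i < len(self.linear_layers) - 1:
+                output = self.act(output)
+        return output
+
+
+head = SimpleMLPHead()
+print(head(torch.randn(4, 320)).shape)  # torch.Size([4, 1])
+```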
+
+## 3 - Fine-Tuned Model
+
+A fine-tuned ESM-2 model class for sequence-level regression tasks. This class inherits from the `ESM2Model` class and adds the custom regression head `MegatronMLPHead` that we created in the previous step. Optionally, one can freeze all or parts of the encoder by iterating over the model parameters in the model constructor.
+
+```python
+class ESM2FineTuneSeqModel(ESM2Model):
+    def __init__(self, config, *args, post_process: bool = True, return_embeddings: bool = False, **kwargs):
+        super().__init__(config, *args, post_process=post_process, return_embeddings=True, **kwargs)
+
+        # freeze encoder parameters
+        if config.encoder_frozen:
+            for _, param in self.named_parameters():
+                param.requires_grad = False
+
+        if post_process:
+            self.regression_head = MegatronMLPHead(config)
+
+    def forward(self, *args, **kwargs):
+        output = super().forward(*args, **kwargs)
+        ...
+        regression_output = self.regression_head(embeddings)
+        return regression_output
+```
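+
+The elided section above produces the `embeddings` that feed the regression head. As a purely hypothetical sketch of the kind of pooling that can turn per-token hidden states into a single sequence-level embedding (the actual ESM-2 implementation may differ), a masked mean over non-padding positions looks like this:
+
+```python
+import torch
+
+
+def masked_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """Average per-token hidden states over real (non-padding) tokens.
+
+    hidden_states: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] with 1 for real tokens.
+    """
+    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # [b, s, 1]
+    summed = (hidden_states * mask).sum(dim=1)                   # [b, h]
+    counts = mask.sum(dim=1).clamp(min=1.0)                      # guard against all-padding rows
+    return summed / counts
+
+
+pooled = masked_mean_pool(torch.randn(2, 8, 320), torch.ones(2, 8, dtype=torch.int64))
+print(pooled.shape)  # torch.Size([2, 320])
+```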
+
+## 4 - Fine-Tuning Config
+
+A `dataclass` that configures the fine-tuned ESM-2 model. In this example `ESM2FineTuneSeqConfig` inherits from `ESM2GenericConfig` and adds custom arguments to setup the fine-tuned model. The `configure_model()` method of this `dataclass` is called within the `Lightning` module to call the model constructor with the `dataclass` arguments.
+
+The common arguments among different fine-tuning tasks are
+
+- `model_cls`: The fine-tune model class (`ESM2FineTuneSeqModel`)
+- `initial_ckpt_path`: BioNeMo 2.0 ESM-2 pre-trained checkpoint
+- `initial_ckpt_skip_keys_with_these_prefixes`: skip keys when loading parameters from a checkpoint. Here we should not look for `regression_head` in the pre-trained checkpoint.
+- `get_loss_reduction_class()`: Implements selection of the appropriate `MegatronLossReduction` class, e.g. `bionemo.esm2.model.finetune.finetune_regressor.RegressorLossReduction`.
+
+```python
+
+@dataclass
+class ESM2FineTuneSeqConfig(ESM2GenericConfig[ESM2FineTuneSeqModel], iom.IOMixinWithGettersSetters):
+    model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel
+    # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
+    # self.override_parent_fields will be loaded from the checkpoint and override those values here.
+    initial_ckpt_path: str | None = None
+    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
+    # that has this new head and want to keep using these weights, please drop this next line or set to []
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
+
+    encoder_frozen: bool = True  # freeze encoder parameters
+    ft_dropout: float = 0.25  # MLP layer dropout
+
+    def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:
+        return RegressorLossReduction
+```
+
+## 5 - Dataset
+
+We will use a sample dataset for demonstration purposes. Create a dataset class by extending from ```torch.utils.data.Dataset```. For the purposes of this demo, we'll assume the dataset consists of a small set of protein sequences, each with a target value of `len(sequence) / 100.0` as its label.
+
+```python
+data = [
+ ("MVLSPADKTNVKAAWGKVGAHAGEYGAEALERH", 0.33),
+ ...
+]
+```
+
+Therefore, the custom BioNeMo dataset class ```bionemo.esm2.model.finetune.finetune_regressor.InMemorySingleValueDataset``` is appropriate, as it facilitates predicting a single value per sequence. An excerpt from the class is shown below. This example dataset expects a sequence of `Tuple`s holding `(sequence, target)` values. However, one can simply extend the ```InMemorySingleValueDataset``` class in a similar way to customize it for your data.
+
+```python
+class InMemorySingleValueDataset(Dataset):
+    def __init__(
+        self,
+        data: Sequence[Tuple[str, float]],
+        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+        seed: int = np.random.SeedSequence().entropy,
+    ):
+```
+
+For arbitrary data file formats, users can process the data into a list of `(sequence, label)` tuples and use this dataset class, or override the dataset class to load their custom data files; a sketch of the first approach follows.
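+
+For example, a minimal sketch of turning a hypothetical CSV file (the file name and column names below are assumptions for illustration) into the `(sequence, label)` tuples this dataset expects, reusing the `InMemorySingleValueDataset` class from the excerpt above:
+
+```python
+import pandas as pd
+
+# hypothetical CSV with "sequences" and "labels" columns
+df = pd.read_csv("my_regression_data.csv")
+data = [(seq, float(label)) for seq, label in zip(df["sequences"], df["labels"])]
+
+dataset = InMemorySingleValueDataset(data)
+```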
+
+To coordinate the creation of training, validation and testing datasets from your data, we need to use a `datamodule` class. To do this we can directly use or extend the ```ESM2FineTuneDataModule``` class (located at ```bionemo.esm2.model.finetune.datamodule.ESM2FineTuneDataModule```) which defines helpful abstract methods that use your dataset class.
+
+```python
+dataset = InMemorySingleValueDataset(data)
+data_module = ESM2FineTuneDataModule(
+    train_dataset=dataset,
+    valid_dataset=dataset,
+    micro_batch_size=4,  # size of a batch to be processed in a device
+    global_batch_size=8,  # size of batch across all devices. Should be multiple of micro_batch_size
+)
+```
+
+# Fine-Tuning the Regressor Task Head for ESM2
+
+Now we can put these five requirements together to fine-tune a regressor task head starting from a pre-trained 650M ESM-2 model (`pretrain_ckpt_path`). We can take advantage of a simple training loop in ```bionemo.esm2.model.finetune.train``` and use the ```train_model()``` function to start the fine-tuning process as follows.
+
+```python
+# create a List[Tuple] with (sequence, target) values
+artificial_sequence_data = [
+ "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+ "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
+ "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+ "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
+ "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
+ "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
+]
+
+data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]
+
+# we are training and validating on the same dataset for simplicity
+dataset = InMemorySingleValueDataset(data)
+data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)
+
+experiment_name = "finetune_regressor"
+n_steps_train = 50
+seed = 42
+
+# To download a 650M pre-trained ESM2 model
+pretrain_ckpt_path = load("esm2/650m:2.0")
+
+config = ESM2FineTuneSeqConfig(
+ initial_ckpt_path=str(pretrain_ckpt_path)
+)
+
+checkpoint, metrics, trainer = train_model(
+ experiment_name=experiment_name,
+ experiment_dir=Path(experiment_results_dir), # new checkpoint will land in a subdir of this
+ config=config, # same config as before since we are just continuing training
+ data_module=data_module,
+ n_steps_train=n_steps_train,
+)
+```
+
+This example is fully implemented in ```bionemo.esm2.model.finetune.train``` and can be executed by:
+```bash
+python -m bionemo.esm2.model.finetune.train
+```
+
+## Notes
+1. The above example is fine-tuning a 650M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load` as shown above.
+ ```bash
+ download_bionemo_data esm2/650m:2.0
+ ```
+ and pass the output path (e.g. `.../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar`) as an argument into `initial_ckpt_path` while setting the config object:
+ ```python
+ config = ESM2FineTuneSeqConfig(
+ initial_ckpt_path=str(pretrain_ckpt_path)
+ )
+ ```
+2. Due to Megatron limitations, the log produced by the training run iterates on steps/iterations and not epochs. Therefore, the `Training epoch` counter stays at zero while `iteration` and `global_step` increase during the course of training (example below).
+ ```bash
+ Training epoch 0, iteration | ... | global_step: | reduced_train_loss: ... | val_loss: ...
+ ```
+    To achieve the same epoch-based effect while training, choose the number of training steps (`n_steps_train`) so that (see the sketch after these notes):
+ ```bash
+ n_steps_train * global_batch_size = len(dataset) * desired_num_epochs
+ ```
+3. We are using a small dataset of artificial sequences as our fine-tuning data in this example. You may experience over-fitting and observe no change in the validation metrics.
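+
+As referenced in note 2, a small helper (an illustrative sketch, not part of the BioNeMo API) for choosing `n_steps_train` from a desired number of epochs:
+
+```python
+import math
+
+
+def steps_for_epochs(dataset_size: int, desired_num_epochs: int, global_batch_size: int) -> int:
+    # solve n_steps_train * global_batch_size = len(dataset) * desired_num_epochs for n_steps_train
+    return math.ceil(dataset_size * desired_num_epochs / global_batch_size)
+
+
+print(steps_for_epochs(dataset_size=10, desired_num_epochs=40, global_batch_size=8))  # 50
+```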
+
+# Fine-Tuned ESM-2 Model Inference
+Now we can use ```bionemo.esm2.model.finetune.train.infer``` to run inference on an example prediction dataset.
+Record the checkpoint path reported at the end of the finetuning run, after executing `python -m bionemo.esm2.model.finetune.train` (e.g. `/tmp/tmp1b5wlnba/finetune_regressor/checkpoints/finetune_regressor--reduced_train_loss=0.0016-epoch=0-last`), and use that as the `--checkpoint-path` argument of the inference script.
+
+We download an example CSV dataset of artificial sequences for this inference example. Please refer to the [ESM-2 Inference](./inference) tutorial for a detailed explanation of the arguments and how to create your own CSV file.
+
+```bash
+mkdir -p $WORKDIR/esm2_finetune_tutorial
+
+# download sample data CSV for inference
+DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0)
+RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/
+
+infer_esm2 --checkpoint-path <path-to-finetuned-checkpoint> \
+ --data-path $DATA_PATH \
+ --results-path $RESULTS_PATH \
+ --config-class ESM2FineTuneSeqConfig
+```
+
+This will create a result `.pt` file under `$WORKDIR/esm2_finetune_tutorial/predictions__rank_0.pt`, which can be loaded via the PyTorch library in a Python environment:
+
+```python
+import torch
+
+# Set the path to the results file, e.g.:
+# results_path = "/workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt"
+results = torch.load(results_path)
+
+# results is a python dict which includes the following result tensors for this example:
+# results['regression_output'] is a tensor with shape: torch.Size([10, 1])
+```
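+
+To see what was returned, you can iterate over the dictionary and print the tensor shapes; the `None` check skips outputs that were not requested:
+
+```python
+for key, val in results.items():
+    if val is not None:
+        print(f"{key}\t{val.shape}")
+```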
+
+## Notes
+- The ESM-2 inference module takes the `--checkpoint-path` and `--config-class` arguments and creates a config object that points `initial_ckpt_path` to that checkpoint. Since we need to load all the parameters from this checkpoint (and not skip the head), we reset `initial_ckpt_skip_keys_with_these_prefixes` in this config.
+
+ ```python
+ config = ESM2FineTuneSeqConfig(
+        initial_ckpt_path="path/to/finetuned/checkpoint",  # placeholder for your fine-tuned checkpoint directory
+        initial_ckpt_skip_keys_with_these_prefixes=[],  # do not skip the regression head weights
+ )
+ ```
diff --git a/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb
index f487e73011..5dfd17964f 100644
--- a/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb
+++ b/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb
@@ -141,40 +141,11 @@
"execution_count": 4,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Downloading data from 'nvidia/clara/esm2nv650m:2.0' to file '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz'.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{\n",
- " \"download_end\": \"2025-01-14 22:01:24\",\n",
- " \"download_start\": \"2025-01-14 22:01:05\",\n",
- " \"download_time\": \"18s\",\n",
- " \"files_downloaded\": 1,\n",
- " \"local_path\": \"/home/ubuntu/.cache/bionemo/tmpfj1e52vw/esm2nv650m_v2.0\",\n",
- " \"size_downloaded\": \"1.12 GB\",\n",
- " \"status\": \"COMPLETED\"\n",
- "}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Untarring contents of '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz' to '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar'\n"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar\n"
+ "/home/bionemo/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar\n"
]
}
],
@@ -197,7 +168,7 @@
"metadata": {},
"source": [
"\n",
- "We use the `InMemoryProteinDataset` class to load the protein sequence data from a `.csv` file. This data file should at least have a `sequences` column and can optionally have a `labels` column used for fine-tuning applications. Here is an example of how to create your own inference input data using a list of sequences in python:"
+ "We use the `InMemoryCSVDataset` class to load the protein sequence data from a `.csv` file. This data file should at least have a `sequences` column and can optionally have a `labels` column used for fine-tuning applications. Here is an example of how to create your own inference input data using a list of sequences in python:"
]
},
{
@@ -267,12 +238,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "2025-01-14 22:01:45 - faiss.loader - INFO - Loading faiss with AVX512 support.\n",
- "2025-01-14 22:01:45 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n",
- "[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
+ "2024-12-16 20:19:23 - faiss.loader - INFO - Loading faiss with AVX512 support.\n",
+ "2024-12-16 20:19:23 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n",
+ "[NeMo W 2024-12-16 20:19:24 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
" warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n",
" \n",
- "[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n",
+ "[NeMo W 2024-12-16 20:19:24 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n",
" cm = get_cmap(\"Set1\")\n",
" \n",
"usage: infer_esm2 [-h] --checkpoint-path CHECKPOINT_PATH --data-path DATA_PATH\n",
@@ -562,7 +533,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.3"
+ "version": "3.10.12"
}
},
"nbformat": 4,
diff --git a/docs/docs/user-guide/getting-started/development.md b/docs/docs/user-guide/getting-started/development.md
index ae8a0997f7..ce97a78cbe 100644
--- a/docs/docs/user-guide/getting-started/development.md
+++ b/docs/docs/user-guide/getting-started/development.md
@@ -136,7 +136,7 @@ of the model. The fine-tuning steps will be application-specific, but a general
6. **Run inference**: Once the model is fine-tuned, use it to make predictions on new, unseen data.
For more information on fine-tuning a model, refer to the [ESM-2 Fine-tuning
-Tutorial](../examples/bionemo-esm2/finetune.ipynb).
+Tutorial](../examples/bionemo-esm2/finetune.md).
## Advanced Developer Documentation
diff --git a/sub-packages/bionemo-esm2/pyproject.toml b/sub-packages/bionemo-esm2/pyproject.toml
index 4acce854fa..aa8f7715ed 100644
--- a/sub-packages/bionemo-esm2/pyproject.toml
+++ b/sub-packages/bionemo-esm2/pyproject.toml
@@ -22,7 +22,6 @@ bionemo-esm2-train= "bionemo.esm2.run.main:main"
bionemo-esm2-recipe= "bionemo.esm2.run.recipes:main"
infer_esm2 = "bionemo.esm2.scripts.infer_esm2:infer_esm2_entrypoint"
train_esm2 = "bionemo.esm2.scripts.train_esm2:train_esm2_entrypoint"
-finetune_esm2 = "bionemo.esm2.scripts.finetune_esm2:finetune_esm2_entrypoint"
# Make sure that the tokenizer files are included along with the python files during installation.
[tool.setuptools.package-data]
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
index 7104f64373..09526572ef 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
@@ -15,27 +15,119 @@
import functools
-from typing import Literal, Union
+import os
+from typing import Literal, Sequence, Tuple, Union
+import numpy as np
+import pandas as pd
+import torch
+import torch.utils.data
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from nemo.lightning.data import WrappedDataLoader
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging
+from torch import Tensor
+from torch.utils.data import Dataset
from bionemo.core.data.multi_epoch_dataset import IdentityMultiEpochDatasetWrapper, MultiEpochDatasetResampler
from bionemo.esm2.data import tokenizer
-from bionemo.esm2.model.finetune.dataset import (
- InMemoryPerTokenValueDataset,
- InMemoryProteinDataset,
- InMemorySingleValueDataset,
-)
+from bionemo.esm2.model.finetune.finetune_regressor import InMemorySingleValueDataset
+from bionemo.esm2.model.finetune.finetune_token_classifier import InMemoryPerTokenValueDataset
from bionemo.llm.data import collate
from bionemo.llm.data.datamodule import MegatronDataModule
+from bionemo.llm.data.types import BertSample
from bionemo.llm.utils.datamodule_utils import infer_num_samples
Mode = Literal["train", "validation", "test", "predict"]
-DATASET_TYPES = Union[InMemoryPerTokenValueDataset, InMemorySingleValueDataset, InMemoryProteinDataset, None]
+
+
+class InMemoryCSVDataset(Dataset):
+ """An in-memory dataset that tokenize strings into BertSample instances."""
+
+ def __init__(
+ self,
+ data_path: str | os.PathLike,
+ tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+ seed: int = np.random.SeedSequence().entropy, # type: ignore
+ ):
+ """Initializes a dataset for single-value regression fine-tuning.
+
+ This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the
+ dataset sequences provided.
+
+ Args:
+ data_path (str | os.PathLike): A path to the CSV file containing sequences.
+ labels (Optional[Sequence[float | str]]): An optional sequence of labels with 1:1 mapping to sequences.
+ tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+ seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+ that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+ generated.
+ """
+ self.sequences, self.labels = self.load_data(data_path)
+
+ self.seed = seed
+ self._len = len(self.sequences)
+ self.tokenizer = tokenizer
+
+ def __len__(self) -> int:
+ """The size of the dataset."""
+ return self._len
+
+ def __getitem__(self, index: int) -> BertSample:
+ """Obtains the BertSample at the given index."""
+ sequence = self.sequences[index]
+ tokenized_sequence = self._tokenize(sequence)
+
+ label = tokenized_sequence if len(self.labels) == 0 else torch.Tensor([self.labels[index]])
+ # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
+ loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
+
+ return {
+ "text": tokenized_sequence,
+ "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+ "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+ "labels": label,
+ "loss_mask": loss_mask,
+ "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+ }
+
+ def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]:
+ """Loads data from a CSV file, returning sequences and optionally labels.
+
+ This method should be implemented by subclasses to process labels for their specific dataset.
+
+ Args:
+ csv_path (str | os.PathLike): The path to the CSV file containing the data.
+ The file is expected to have at least one column named 'sequence'. A 'label' column is optional.
+
+ Returns:
+ Tuple[Sequence, Sequence]: A tuple where the first element is a list of sequences and the second element is
+ a list of labels. If the 'label' column is not present, an empty list is returned for labels.
+ """
+ df = pd.read_csv(csv_path)
+ sequences = df["sequences"].tolist()
+
+ if "labels" in df.columns:
+ labels = df["labels"].tolist()
+ else:
+ labels = []
+ return sequences, labels
+
+ def _tokenize(self, sequence: str) -> Tensor:
+ """Tokenize a protein sequence.
+
+ Args:
+ sequence: The protein sequence.
+
+ Returns:
+ The tokenized sequence.
+ """
+ tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
+ return tensor.flatten() # type: ignore
+
+
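+# Dataset types accepted by ESM2FineTuneDataModule for its train/valid/predict datasets (None when a split is unused).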
+DATASET_TYPES = Union[InMemoryPerTokenValueDataset, InMemorySingleValueDataset, InMemoryCSVDataset, None]
class ESM2FineTuneDataModule(MegatronDataModule):
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py
deleted file mode 100644
index 542854548d..0000000000
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py
+++ /dev/null
@@ -1,221 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import os
-from typing import Sequence
-
-import numpy as np
-import pandas as pd
-import torch
-import torch.utils.data
-from torch import Tensor
-from torch.utils.data import Dataset
-
-from bionemo.esm2.data import tokenizer
-from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX
-from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer
-from bionemo.llm.data.types import BertSample
-
-
-__all__: Sequence[str] = (
- "InMemoryProteinDataset",
- "InMemorySingleValueDataset",
- "InMemoryPerTokenValueDataset",
-)
-
-
-class InMemoryProteinDataset(Dataset):
- """An in-memory dataset that tokenize strings into BertSample instances."""
-
- def __init__(
- self,
- sequences: pd.Series,
- labels: pd.Series | None = None,
- tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
- seed: int = np.random.SeedSequence().entropy, # type: ignore
- ):
- """Initializes a dataset of protein sequences.
-
- This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the
- dataset sequences provided.
-
- Args:
- sequences (pd.Series): A pandas Series containing protein sequences.
- labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
- tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
- seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
- that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
- generated.
- """
- self.sequences = sequences
- self.labels = labels
-
- self.seed = seed
- self._len = len(self.sequences)
- self.tokenizer = tokenizer
-
- @classmethod
- def from_csv(
- cls, csv_path: str | os.PathLike, tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer()
- ):
- """Class method to create a ProteinDataset instance from a CSV file."""
- df = pd.read_csv(csv_path)
-
- # Validate presence of required columns
- if "sequences" not in df.columns:
- raise KeyError("The CSV must contain a 'sequences' column.")
-
- sequences = df["sequences"]
- labels = df["labels"] if "labels" in df.columns else None
- return cls(sequences, labels, tokenizer)
-
- def __len__(self) -> int:
- """The size of the dataset."""
- return self._len
-
- def __getitem__(self, index: int) -> BertSample:
- """Obtains the BertSample at the given index."""
- sequence = self.sequences[index]
- tokenized_sequence = self._tokenize(sequence)
-
- label = tokenized_sequence if self.labels is None else self.transform_label(self.labels.iloc[index])
- # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
- loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
-
- return {
- "text": tokenized_sequence,
- "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
- "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
- "labels": label,
- "loss_mask": loss_mask,
- "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
- }
-
- def _tokenize(self, sequence: str) -> Tensor:
- """Tokenize a protein sequence.
-
- Args:
- sequence: The protein sequence.
-
- Returns:
- The tokenized sequence.
- """
- tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
- return tensor.flatten() # type: ignore
-
- def transform_label(self, label):
- """Transform the label.
-
- This method should be implemented by subclass if label needs additional transformation.
-
- Args:
- label: label to be transformed
-
- Returns:
- transformed_label
- """
- return label
-
-
-class InMemorySingleValueDataset(InMemoryProteinDataset):
- """An in-memory dataset that tokenizes strings into BertSample instances."""
-
- def __init__(
- self,
- sequences: pd.Series,
- labels: pd.Series | None = None,
- tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
- seed: int = np.random.SeedSequence().entropy, # type: ignore
- ):
- """Initializes a dataset for single-value regression fine-tuning.
-
- This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the
- dataset sequences provided.
-
- Args:
- sequences (pd.Series): A pandas Series containing protein sequences.
- labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
- tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
- seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
- that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
- generated.
- """
- super().__init__(sequences, labels, tokenizer, seed)
-
- def transform_label(self, label: float) -> Tensor:
- """Transform the regression label.
-
- Args:
- label: regression value
-
- Returns:
- tokenized label
- """
- return torch.tensor([label], dtype=torch.float)
-
-
-class InMemoryPerTokenValueDataset(InMemoryProteinDataset):
- """An in-memory dataset of labeled strings, which are tokenized on demand."""
-
- def __init__(
- self,
- sequences: pd.Series,
- labels: pd.Series | None = None,
- tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
- seed: int = np.random.SeedSequence().entropy, # type: ignore
- ):
- """Initializes a dataset for per-token classification fine-tuning.
-
- This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the
- dataset sequences provided.
-
- Args:
- sequences (pd.Series): A pandas Series containing protein sequences.
- labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
- tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
- seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
- that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
- generated.
- """
- super().__init__(sequences, labels, tokenizer, seed)
- label_tokenizer = Label2IDTokenizer()
- self.label_tokenizer = label_tokenizer.build_vocab("CHE")
- self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX
-
- def transform_label(self, label: str) -> Tensor:
- """Transform the sequence label by tokenizing them.
-
- This method tokenizes the secondary structure token sequences.
-
- Args:
- label: secondary structure token sequences to be transformed
-
- Returns:
- tokenized label
- """
- label_ids = torch.tensor(self.label_tokenizer.text_to_ids(label))
-
- # # for multi-label classification with BCEWithLogitsLoss
- # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
- # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype)
-
- # for multi-class (mutually exclusive) classification with CrossEntropyLoss
- tokenized_labels = label_ids
- cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype)
-
- # add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence
- labels = torch.cat((cls_eos, tokenized_labels, cls_eos))
- return labels
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py
index 27d3375864..f63a194190 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py
@@ -17,13 +17,17 @@
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Tuple, Type
+import numpy as np
import torch
from megatron.core import parallel_state
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import Tensor
+from torch.utils.data import Dataset
from bionemo.esm2.api import ESM2GenericConfig, ESM2Model
+from bionemo.esm2.data import tokenizer
+from bionemo.llm.data.types import BertSample
from bionemo.llm.model.biobert.model import BioBertOutput
from bionemo.llm.model.loss import BERTMLMLossWithReduction, PerTokenLossDict, SameSizeLossDict
from bionemo.llm.utils import iomixin_utils as iom
@@ -37,6 +41,7 @@
"MegatronMLPHead",
"ESM2FineTuneSeqModel",
"ESM2FineTuneSeqConfig",
+ "InMemorySingleValueDataset",
)
@@ -173,3 +178,61 @@ class ESM2FineTuneSeqConfig(
def get_loss_reduction_class(self) -> Type[RegressorLossReduction]:
"""Returns RegressorLossReduction class."""
return RegressorLossReduction
+
+
+class InMemorySingleValueDataset(Dataset):
+ """An in-memory dataset that tokenizes strings into BertSample instances."""
+
+ def __init__(
+ self,
+ data: Sequence[Tuple[str, float]],
+ tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+ seed: int = np.random.SeedSequence().entropy, # type: ignore
+ ):
+ """Initializes a dataset for single-value regression fine-tuning.
+
+ This is an in-memory dataset that does not apply masking to the sequence.
+
+ Args:
+ data (Sequence[Tuple[str, float]]): A sequence of tuples containing the sequence and target data.
+ tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+ seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+ that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+ generated.
+ """
+ self.data = data
+ self.seed = seed
+ self._len = len(self.data)
+ self.tokenizer = tokenizer
+
+ def __len__(self) -> int:
+ """The size of the dataset."""
+ return self._len
+
+ def __getitem__(self, index: int) -> BertSample:
+ """Obtains the BertSample at the given index."""
+ sequence, target = self.data[index]
+ tokenized_sequence = self._tokenize(sequence)
+ # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
+ loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
+
+ return {
+ "text": tokenized_sequence,
+ "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+ "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+ "labels": torch.tensor([target], dtype=torch.float),
+ "loss_mask": loss_mask,
+ "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+ }
+
+ def _tokenize(self, sequence: str) -> Tensor:
+ """Tokenize a protein sequence.
+
+ Args:
+ sequence: The protein sequence.
+
+ Returns:
+ The tokenized sequence.
+ """
+ tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
+ return tensor.flatten() # type: ignore
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py
index fe67cf2ac8..b0991669d8 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py
@@ -15,15 +15,21 @@
from dataclasses import dataclass, field
-from typing import Dict, List, Sequence, Tuple, Type
+from typing import List, Sequence, Tuple, Type, TypedDict
+import numpy as np
import torch
from megatron.core import parallel_state
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import Tensor
+from torch.utils.data import Dataset
from bionemo.esm2.api import ESM2GenericConfig, ESM2Model
+from bionemo.esm2.data import tokenizer
+from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX
+from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer
+from bionemo.llm.data.types import BertSample
from bionemo.llm.model.biobert.model import BioBertOutput
from bionemo.llm.model.loss import BERTMLMLossWithReduction, PerTokenLossDict, SameSizeLossDict
from bionemo.llm.utils import iomixin_utils as iom
@@ -38,9 +44,25 @@
"MegatronConvNetHead",
"ESM2FineTuneTokenModel",
"ESM2FineTuneTokenConfig",
+ "InMemoryPerTokenValueDataset",
+ "ClassifierInput",
+ "Esm2FineTuneTokenOutput",
)
+class ClassifierInput(TypedDict):
+ """Used as input in the ClassifierLossReduction's forward method."""
+
+ labels: Tensor
+ loss_mask: Tensor
+
+
+class Esm2FineTuneTokenOutput(BioBertOutput):
+ """Inference output from ESM2FineTuneTokenModel."""
+
+ classification_output: Tensor
+
+
class ClassifierLossReduction(BERTMLMLossWithReduction):
"""A class for calculating the cross entropy loss of classification output.
@@ -48,7 +70,7 @@ class ClassifierLossReduction(BERTMLMLossWithReduction):
"""
def forward(
- self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
+ self, batch: ClassifierInput, forward_out: Esm2FineTuneTokenOutput
) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
"""Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
@@ -137,9 +159,9 @@ def __init__(self, config, *args, include_hiddens: bool = False, post_process: b
# if we are doing post process (eg pipeline last stage) then we need to add the output layers
self.classification_head = MegatronConvNetHead(config)
- def forward(self, *args, **kwargs) -> Tensor | BioBertOutput:
+ def forward(self, *args, **kwargs) -> Tensor | BioBertOutput | Esm2FineTuneTokenOutput:
"""Inference."""
- output = super().forward(*args, **kwargs)
+ output: Tensor | BioBertOutput | Esm2FineTuneTokenOutput = super().forward(*args, **kwargs)
# Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
if not self.post_process:
return output # we are not at the last pipeline stage so just return what the parent has
@@ -181,3 +203,80 @@ class ESM2FineTuneTokenConfig(
def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
"""The loss function type."""
return ClassifierLossReduction
+
+
+class InMemoryPerTokenValueDataset(Dataset):
+ """An in-memory dataset of labeled strings, which are tokenized on demand."""
+
+ def __init__(
+ self,
+ data: Sequence[Tuple[str, str]],
+ tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+ seed: int = np.random.SeedSequence().entropy, # type: ignore
+ ):
+ """Initializes a dataset for per-token classification fine-tuning.
+
+ This is an in-memory dataset that does not apply masking to the sequence.
+
+ Args:
+ data: A sequence of tuples containing the sequence and target data.
+ tokenizer: The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+ seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to
+ ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random
+ seed is generated.
+ """
+ self.data = data
+ self.seed = seed
+ self._len = len(self.data)
+ self.tokenizer = tokenizer
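+        # Build the label vocabulary over the 3-state secondary structure alphabet "C", "H", "E"
+        # (coil, helix, strand); CLS/EOS positions are filled with the loss-ignore index so they are skipped by the loss.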
+ label_tokenizer = Label2IDTokenizer()
+ self.label_tokenizer = label_tokenizer.build_vocab("CHE")
+ self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX
+
+ def __len__(self) -> int:
+ """Length of dataset."""
+ return self._len
+
+ def __getitem__(self, index: int) -> BertSample:
+ """Gets a BertSample associated to the supplied index."""
+ sequence, target = self.data[index]
+ tokenized_sequence = self._tokenize(sequence)
+ # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
+ loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(self.tokenizer.all_special_ids))
+ labels = self._tokenize_labels(target)
+
+ return {
+ "text": tokenized_sequence,
+ "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+ "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+ "labels": labels,
+ "loss_mask": loss_mask,
+ "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+ }
+
+ def _tokenize_labels(self, labels_sequence: str) -> Tensor:
+ label_ids = torch.tensor(self.label_tokenizer.text_to_ids(labels_sequence))
+
+ # # for multi-label classification with BCEWithLogitsLoss
+ # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
+ # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype)
+
+ # for multi-class (mutually exclusive) classification with CrossEntropyLoss
+ tokenized_labels = label_ids
+ cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype)
+
+ # add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence
+ labels = torch.cat((cls_eos, tokenized_labels, cls_eos))
+ return labels
+
+ def _tokenize(self, sequence: str) -> Tensor:
+ """Tokenize a protein sequence.
+
+ Args:
+ sequence: The protein sequence.
+
+ Returns:
+ The tokenized sequence.
+ """
+ tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
+ return tensor.flatten() # type: ignore
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
new file mode 100644
index 0000000000..638729e3f4
--- /dev/null
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
@@ -0,0 +1,189 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import tempfile
+from pathlib import Path
+from typing import Sequence, Tuple
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import Callback, RichModelSummary
+from lightning.pytorch.loggers import TensorBoardLogger
+from megatron.core.optimizer.optimizer_config import OptimizerConfig
+from nemo import lightning as nl
+from nemo.collections import llm as nllm
+from nemo.lightning import resume
+from nemo.lightning.nemo_logger import NeMoLogger
+from nemo.lightning.pytorch import callbacks as nl_callbacks
+from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
+from nemo.lightning.pytorch.callbacks.peft import PEFT
+from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
+
+from bionemo.core.data.load import load
+from bionemo.esm2.api import ESM2GenericConfig
+from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer
+from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
+from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig, InMemorySingleValueDataset
+from bionemo.llm.model.biobert.lightning import biobert_lightning_module
+
+
+__all__: Sequence[str] = ("train_model",)
+
+
+def train_model(
+ experiment_name: str,
+ experiment_dir: Path,
+ config: ESM2GenericConfig,
+ data_module: pl.LightningDataModule,
+ n_steps_train: int,
+ metric_tracker: Callback | None = None,
+ tokenizer: BioNeMoESMTokenizer = get_tokenizer(),
+ peft: PEFT | None = None,
+ _use_rich_model_summary: bool = True,
+) -> Tuple[Path, Callback | None, nl.Trainer]:
+ """Trains a BioNeMo ESM2 model using PyTorch Lightning.
+
+ Parameters:
+ experiment_name: The name of the experiment.
+ experiment_dir: The directory where the experiment will be saved.
+ config: The configuration for the ESM2 model.
+ data_module: The data module for training and validation.
+ n_steps_train: The number of training steps.
+ metric_tracker: Optional callback to track metrics
+ tokenizer: The tokenizer to use. Defaults to `get_tokenizer()`.
+ peft: The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to None.
+ _use_rich_model_summary: Whether to use the RichModelSummary callback, omitted in our test suite until
+ https://nvbugspro.nvidia.com/bug/4959776 is resolved. Defaults to True.
+
+ Returns:
+ A tuple containing the path to the saved checkpoint, a MetricTracker
+ object, and the PyTorch Lightning Trainer object.
+ """
+ checkpoint_callback = nl_callbacks.ModelCheckpoint(
+ save_last=True,
+ save_on_train_epoch_end=True,
+ monitor="reduced_train_loss", # TODO find out how to get val_loss logged and use "val_loss",
+ every_n_train_steps=n_steps_train // 2,
+ always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+ )
+
+ # Setup the logger and train the model
+ nemo_logger = NeMoLogger(
+ log_dir=str(experiment_dir),
+ name=experiment_name,
+ tensorboard=TensorBoardLogger(save_dir=experiment_dir, name=experiment_name),
+ ckpt=checkpoint_callback,
+ )
+ # Needed so that the trainer can find an output directory for the profiler
+ # ckpt_path needs to be a string for SerDe
+ optimizer = MegatronOptimizerModule(
+ config=OptimizerConfig(
+ lr=5e-4,
+ optimizer="adam",
+ use_distributed_optimizer=True,
+ fp16=config.fp16,
+ bf16=config.bf16,
+ )
+ )
+ module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft)
+
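+    # Single-GPU Megatron strategy; tensor and pipeline model parallel sizes are both 1 for this fine-tuning example.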
+ strategy = nl.MegatronStrategy(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ ddp="megatron",
+ find_unused_parameters=True,
+ enable_nemo_ckpt_io=True,
+ )
+
+ if _use_rich_model_summary:
+ # RichModelSummary is not used in the test suite until https://nvbugspro.nvidia.com/bug/4959776 is resolved due
+ # to errors with serialization / deserialization.
+ callbacks: list[Callback] = [RichModelSummary(max_depth=4)]
+ else:
+ callbacks = []
+
+ if metric_tracker is not None:
+ callbacks.append(metric_tracker)
+ if peft is not None:
+ callbacks.append(
+ ModelTransform()
+ ) # Callback needed for PEFT fine-tuning using NeMo2, i.e. biobert_lightning_module(model_transform=peft).
+
+ trainer = nl.Trainer(
+ accelerator="gpu",
+ devices=1,
+ strategy=strategy,
+ limit_val_batches=2,
+ val_check_interval=n_steps_train // 2,
+ max_steps=n_steps_train,
+ num_nodes=1,
+ log_every_n_steps=n_steps_train // 2,
+ callbacks=callbacks,
+ plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+ )
+ nllm.train(
+ model=module,
+ data=data_module,
+ trainer=trainer,
+ log=nemo_logger,
+ resume=resume.AutoResume(
+ resume_if_exists=True, # Looks for the -last checkpoint to continue training.
+ resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
+ ),
+ )
+ ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
+ return ckpt_path, metric_tracker, trainer
+
+
+if __name__ == "__main__":
+    # set the results directory; mkdtemp keeps the directory on disk for the duration of the run
+    experiment_results_dir = tempfile.mkdtemp()
+
+ # create a List[Tuple] with (sequence, target) values
+ artificial_sequence_data = [
+ "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+ "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
+ "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+ "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
+ "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
+ "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
+ ]
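+    # Pair each sequence with a synthetic scalar target (sequence length / 100) so the example is self-contained.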
+ data = [(seq, len(seq) / 100.0) for seq in artificial_sequence_data]
+
+ # we are training and validating on the same dataset for simplicity
+ dataset = InMemorySingleValueDataset(data)
+ data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)
+
+ experiment_name = "finetune_regressor"
+ n_steps_train = 50
+ seed = 42
+
+ # To download a 650M pre-trained ESM2 model
+ pretrain_ckpt_path = load("esm2/650m:2.0")
+
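+    # initial_ckpt_path restores the pre-trained ESM2 weights into the fine-tuning config before training starts.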
+ config = ESM2FineTuneSeqConfig(initial_ckpt_path=str(pretrain_ckpt_path))
+
+ checkpoint, metrics, trainer = train_model(
+ experiment_name=experiment_name,
+ experiment_dir=Path(experiment_results_dir), # new checkpoint will land in a subdir of this
+ config=config, # same config as before since we are just continuing training
+ data_module=data_module,
+ n_steps_train=n_steps_train,
+ )
+ print(f"Experiment completed with checkpoint stored at {checkpoint}")
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py
deleted file mode 100644
index 1b35f169ff..0000000000
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py
+++ /dev/null
@@ -1,635 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-from pathlib import Path
-from typing import Dict, List, Optional, Sequence, Tuple, Type, get_args
-
-from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary
-from megatron.core.distributed import DistributedDataParallelConfig
-from megatron.core.optimizer import OptimizerConfig
-from nemo import lightning as nl
-from nemo.collections import llm
-from nemo.lightning import resume
-from nemo.lightning.pytorch import callbacks as nl_callbacks
-from nemo.lightning.pytorch.optim import MegatronOptimizerModule
-
-from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
-from bionemo.esm2.data.tokenizer import get_tokenizer
-from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
-from bionemo.esm2.model.finetune.dataset import (
- InMemoryPerTokenValueDataset,
- InMemoryProteinDataset,
- InMemorySingleValueDataset,
-)
-from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
-from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
-from bionemo.llm.model.biobert.lightning import biobert_lightning_module
-from bionemo.llm.model.biobert.model import BioBertConfig
-from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size
-from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger
-
-
-__all__: Sequence[str] = ("train_model", "finetune_esm2_entrypoint", "get_parser")
-
-
-SUPPORTED_CONFIGS = {
- "ESM2FineTuneSeqConfig": ESM2FineTuneSeqConfig,
- "ESM2FineTuneTokenConfig": ESM2FineTuneTokenConfig,
-}
-
-SUPPORTED_DATASETS = {
- "InMemoryProteinDataset": InMemoryProteinDataset,
- "InMemorySingleValueDataset": InMemorySingleValueDataset,
- "InMemoryPerTokenValueDataset": InMemoryPerTokenValueDataset,
-}
-
-
-def train_model(
- train_data_path: Path,
- valid_data_path: Path,
- num_nodes: int,
- devices: int,
- min_seq_length: Optional[int],
- max_seq_length: int,
- result_dir: Path,
- num_steps: int,
- limit_val_batches: int,
- val_check_interval: int,
- log_every_n_steps: Optional[int],
- num_dataset_workers: int,
- lr: float,
- micro_batch_size: int,
- accumulate_grad_batches: int,
- experiment_name: str,
- resume_if_exists: bool,
- precision: PrecisionTypes,
- wandb_entity: Optional[str] = None,
- wandb_project: Optional[str] = None,
- wandb_offline: bool = False,
- wandb_tags: Optional[List[str]] = None,
- wandb_group: Optional[str] = None,
- wandb_id: Optional[str] = None,
- wandb_anonymous: Optional[bool] = False,
- wandb_log_model: bool = False,
- pipeline_model_parallel_size: int = 1,
- tensor_model_parallel_size: int = 1,
- create_tensorboard_logger: bool = False,
- restore_from_checkpoint_path: Optional[str] = None,
- save_last_checkpoint: bool = True,
- metric_to_monitor_for_checkpoints: str = "val_loss",
- save_top_k: int = 2,
- nsys_profiling: bool = False,
- nsys_start_step: int = 0,
- nsys_end_step: Optional[int] = None,
- nsys_ranks: List[int] = [0],
- dataset_class: Type[InMemoryProteinDataset] = InMemorySingleValueDataset,
- config_class: Type[BioBertConfig] = ESM2FineTuneSeqConfig,
- metric_tracker: Callback | None = None,
- overlap_grad_reduce: bool = True,
- overlap_param_gather: bool = True,
- average_in_collective: bool = True,
- grad_reduce_in_fp32: bool = False,
-) -> Tuple[Path, Callback | None, nl.Trainer]:
- """Train an ESM2 model on UR data.
-
- Args:
- train_data_path (Path): path to train CSV
- valid_data_path (Path): path to validation CSV
- num_nodes (int): Number of nodes to run on
- devices (int): number of devices
- min_seq_length (Optional[int]): minimum sequence length
- max_seq_length (int): maximum sequence length
- result_dir (Path): directory to store results, logs and checkpoints
- num_steps (int): number of steps to train the model for
- limit_val_batches (int): limit the number of validation global batches to this many
- val_check_interval (int): number of steps to periodically check the validation loss
- log_every_n_steps (Optional[int]): log every n steps
- num_dataset_workers (int): number of dataset workers
- lr (float): learning rate
- micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
- accumulate_grad_batches (int): number of batches to accumulate gradients for
- experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
- result_dir that stores the logs and checkpoints.
- resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
- precision (PrecisionTypes): Precision type for training (e.g., float16, float32)
- wandb_entity (Optional[str]): The team posting this run (default: your username or your default team)
- wandb_project (Optional[str]): The name of the project to which this run will belong
- wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
- wandb_tags (Optional[List[str]]): Tags associated with this run
- wandb_group (Optional[str]): A unique string shared by all runs in a given group
- wandb_id (Optional[str]): Sets the version, mainly used to resume a previous run
- wandb_anonymous (Optional[bool]): Enables or explicitly disables anonymous logging
- wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers
- pipeline_model_parallel_size (int): pipeline model parallel size
- tensor_model_parallel_size (int): tensor model parallel size
- create_tensorboard_logger (bool): create the tensorboard logger
- restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the
- checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
- save_last_checkpoint (bool): whether to save the last checkpoint
- metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints
- save_top_k (int): number of top checkpoints to save
- nsys_profiling (bool): whether to enable nsys profiling
- nsys_start_step (int): start step for nsys profiling
- nsys_end_step (Optional[int]): end step for nsys profiling
- nsys_ranks (List[int]): ranks for nsys profiling
- dataset_class (Type[InMemoryProteinDataset]): The dataset class for loading the data from a CSV file
- config_class (Type[BioBertConfig]): The config class for configuring the model using checkpoint provided
- metric_tracker: Optional callback to track metrics (used for testing)
- overlap_grad_reduce (bool): overlap gradient reduction
- overlap_param_gather (bool): overlap parameter gather
- average_in_collective (bool): average in collective
- grad_reduce_in_fp32 (bool): gradient reduction in fp32
- """
- # Create the result directory if it does not exist.
- result_dir.mkdir(parents=True, exist_ok=True)
-
- # Setup the strategy and trainer
- global_batch_size = infer_global_batch_size(
- micro_batch_size=micro_batch_size,
- num_nodes=num_nodes,
- devices=devices,
- accumulate_grad_batches=accumulate_grad_batches,
- tensor_model_parallel_size=tensor_model_parallel_size,
- pipeline_model_parallel_size=pipeline_model_parallel_size,
- )
-
- strategy = nl.MegatronStrategy(
- tensor_model_parallel_size=tensor_model_parallel_size,
- pipeline_model_parallel_size=pipeline_model_parallel_size,
- ddp=DistributedDataParallelConfig(
- check_for_nan_in_grad=True,
- overlap_grad_reduce=overlap_grad_reduce,
- overlap_param_gather=overlap_param_gather,
- average_in_collective=average_in_collective,
- grad_reduce_in_fp32=grad_reduce_in_fp32,
- use_distributed_optimizer=True,
- ),
- find_unused_parameters=True,
- gradient_as_bucket_view=True,
- ckpt_include_optimizer=True,
- ckpt_async_save=True,
- ckpt_parallel_load=True,
- )
-
- # for wandb integration
- # Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/lightning.pytorch.loggers.html"
- wandb_config: Optional[WandbConfig] = (
- None
- if wandb_project is None
- else WandbConfig(
- offline=wandb_offline,
- project=wandb_project,
- entity=wandb_entity,
- tags=wandb_tags,
- group=wandb_group,
- id=wandb_id,
- anonymous=wandb_anonymous,
- log_model=wandb_log_model,
- )
- )
-
- callbacks = [
- RichModelSummary(max_depth=4),
- LearningRateMonitor(),
- nl_callbacks.PreemptionCallback(),
- ]
- if metric_tracker is not None:
- callbacks.append(metric_tracker)
- if nsys_profiling:
- if nsys_end_step is None:
- nsys_end_step = num_steps
- callbacks.append(
- nl_callbacks.NsysCallback(
- start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
- )
- )
-
- trainer = nl.Trainer(
- devices=devices,
- max_steps=num_steps,
- accelerator="gpu",
- strategy=strategy,
- limit_val_batches=limit_val_batches, # This controls upsampling and downsampling
- val_check_interval=val_check_interval,
- log_every_n_steps=log_every_n_steps,
- num_nodes=num_nodes,
- callbacks=callbacks,
- plugins=nl.MegatronMixedPrecision(
- precision=precision,
- params_dtype=get_autocast_dtype(precision),
- pipeline_dtype=get_autocast_dtype(precision),
- grad_reduce_in_fp32=grad_reduce_in_fp32,
- autocast_enabled=False,
- ),
- )
-
- tokenizer = get_tokenizer()
-
- # Initialize the data module.
- train_dataset = dataset_class.from_csv(train_data_path)
- valid_dataset = dataset_class.from_csv(valid_data_path)
-
- data_module = ESM2FineTuneDataModule(
- train_dataset=train_dataset,
- valid_dataset=valid_dataset,
- global_batch_size=global_batch_size,
- micro_batch_size=micro_batch_size,
- min_seq_length=min_seq_length,
- max_seq_length=max_seq_length,
- num_workers=num_dataset_workers,
- tokenizer=tokenizer,
- )
- # Configure the model
- config = config_class(
- params_dtype=get_autocast_dtype(precision),
- pipeline_dtype=get_autocast_dtype(precision),
- autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot
- tensor_model_parallel_size=tensor_model_parallel_size,
- pipeline_model_parallel_size=pipeline_model_parallel_size,
- initial_ckpt_path=str(restore_from_checkpoint_path),
- # initial_ckpt_skip_keys_with_these_prefixes=[], # load everything from the checkpoint.
- )
-
- optimizer = MegatronOptimizerModule(
- config=OptimizerConfig(
- lr=lr,
- optimizer="adam", # fused_adam not supported
- use_distributed_optimizer=True,
- weight_decay=0.01,
- adam_beta1=0.9,
- adam_beta2=0.98,
- )
- )
-
- module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer)
-
- # Configure our custom Checkpointer
- checkpoint_callback = nl_callbacks.ModelCheckpoint(
- save_last=save_last_checkpoint,
- monitor=metric_to_monitor_for_checkpoints, # "val_loss",
- save_top_k=save_top_k,
- every_n_train_steps=val_check_interval,
- always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
- filename="checkpoint-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
- )
-
- # Setup the logger and train the model
- nemo_logger = setup_nemo_lightning_logger(
- root_dir=result_dir,
- name=experiment_name,
- initialize_tensorboard_logger=create_tensorboard_logger,
- wandb_config=wandb_config,
- ckpt_callback=checkpoint_callback,
- )
-
- llm.train(
- model=module,
- data=data_module,
- trainer=trainer,
- log=nemo_logger,
- resume=resume.AutoResume(
- resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training.
- resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
- ),
- )
- ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
- return ckpt_path, metric_tracker, trainer
-
-
-def finetune_esm2_entrypoint():
- """Entrypoint for running ESM2 finetuning."""
- # 1. get arguments
- parser = get_parser()
- args = parser.parse_args()
- # 2. Call pretrain with args
- train_model(
- train_data_path=args.train_data_path,
- valid_data_path=args.valid_data_path,
- num_nodes=args.num_nodes,
- devices=args.num_gpus,
- min_seq_length=args.min_seq_length,
- max_seq_length=args.max_seq_length,
- result_dir=args.result_dir,
- wandb_entity=args.wandb_entity,
- wandb_project=args.wandb_project,
- wandb_tags=args.wandb_tags,
- wandb_group=args.wandb_group,
- wandb_id=args.wandb_id,
- wandb_anonymous=args.wandb_anonymous,
- wandb_log_model=args.wandb_log_model,
- wandb_offline=args.wandb_offline,
- num_steps=args.num_steps,
- limit_val_batches=args.limit_val_batches,
- val_check_interval=args.val_check_interval,
- log_every_n_steps=args.log_every_n_steps,
- num_dataset_workers=args.num_dataset_workers,
- lr=args.lr,
- micro_batch_size=args.micro_batch_size,
- pipeline_model_parallel_size=args.pipeline_model_parallel_size,
- tensor_model_parallel_size=args.tensor_model_parallel_size,
- accumulate_grad_batches=args.accumulate_grad_batches,
- precision=args.precision,
- experiment_name=args.experiment_name,
- resume_if_exists=args.resume_if_exists,
- restore_from_checkpoint_path=args.restore_from_checkpoint_path,
- save_last_checkpoint=args.save_last_checkpoint,
- metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
- save_top_k=args.save_top_k,
- nsys_profiling=args.nsys_profiling,
- nsys_start_step=args.nsys_start_step,
- nsys_end_step=args.nsys_end_step,
- nsys_ranks=args.nsys_ranks,
- dataset_class=args.dataset_class,
- config_class=args.config_class,
- overlap_grad_reduce=not args.no_overlap_grad_reduce,
- overlap_param_gather=not args.no_overlap_param_gather,
- average_in_collective=not args.no_average_in_collective,
- grad_reduce_in_fp32=args.grad_reduce_in_fp32,
- )
-
-
-def get_parser():
- """Return the cli parser for this tool."""
- # TODO migrate to hydra config
- # Parse the arguments and pull them out into local variables for ease of future refactor to a
- # config management system.
- parser = argparse.ArgumentParser(description="Pretrain ESM2 with UR data.")
- parser.add_argument(
- "--train-data-path",
- type=Path,
- required=True,
- help="Path to the train data CSV file",
- )
- parser.add_argument(
- "--valid-data-path",
- type=Path,
- required=True,
- help="Path to the valid data CSV file",
- )
- parser.add_argument(
- "--precision",
- type=str,
- choices=get_args(PrecisionTypes),
- required=False,
- default="bf16-mixed",
- help="Precision type to use for training.",
- )
- parser.add_argument(
- "--lr",
- type=float,
- required=False,
- default=4e-4,
- help="Learning rate for training. Default is 4e-4",
- )
- parser.add_argument(
- "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
- )
- # FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer
- parser.add_argument(
- "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists."
- )
- parser.add_argument(
- "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory."
- )
- parser.add_argument("--experiment-name", type=str, required=False, default="esm2", help="Name of the experiment.")
-
- parser.add_argument("--wandb-entity", type=str, default=None, help="The team posting this run")
- parser.add_argument("--wandb-project", type=str, default=None, help="Wandb project name ")
- parser.add_argument("--wandb-tags", nargs="+", type=str, default=None, help="Tags associated with this run")
- parser.add_argument(
- "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group"
- )
- parser.add_argument(
- "--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run"
- )
- parser.add_argument(
- "--wandb-anonymous", action="store_true", help="Enable or explicitly disable anonymous logging"
- )
- parser.add_argument(
- "--wandb-log-model", action="store_true", help="Save checkpoints in wandb dir to upload on W&B servers"
- )
- parser.add_argument("--wandb-offline", action="store_true", help="Use wandb in offline mode")
- parser.add_argument(
- "--num-gpus",
- type=int,
- required=False,
- default=1,
- help="Number of GPUs to use for training. Default is 1.",
- )
- parser.add_argument(
- "--num-nodes",
- type=int,
- required=False,
- default=1,
- help="Number of nodes to use for training. Default is 1.",
- )
- parser.add_argument(
- "--num-steps",
- type=int,
- required=False,
- default=500000,
- help="Number of steps to use for training. Default is 500000.",
- )
- parser.add_argument(
- "--num-dataset-workers",
- type=int,
- required=False,
- default=1,
- help="Number of workers to use for training. Default is 1.",
- )
- parser.add_argument(
- "--val-check-interval",
- type=int,
- required=False,
- default=10000,
- help="Number of steps between validation. Default is 10000.",
- )
- parser.add_argument(
- "--log-every-n-steps",
- type=int,
- required=False,
- help="Number of steps between logging. Default is 50.",
- )
- parser.add_argument(
- "--min-seq-length",
- type=float_or_int_or_none,
- required=False,
- default=1024,
- help="Minimum sequence length. Sampled will be padded if less than this value. Set 'None' to unset minimum.",
- )
- parser.add_argument(
- "--max-seq-length",
- type=int,
- required=False,
- default=1024,
- help="Maximum sequence length. Samples will be truncated if exceeds this value.",
- )
- parser.add_argument(
- "--limit-val-batches",
- type=float_or_int_or_none,
- required=False,
- default=2,
- help="Number of global batches used for validation if int. Fraction of validation dataset if float. Default is 2.",
- )
- parser.add_argument(
- "--micro-batch-size",
- type=int,
- required=False,
- default=64,
- help="Micro-batch size. Global batch size is inferred from this.",
- )
- parser.add_argument(
- "--pipeline-model-parallel-size",
- type=int,
- required=False,
- default=1,
- help="Pipeline model parallel size. Default is 1.",
- )
- parser.add_argument(
- "--tensor-model-parallel-size",
- type=int,
- required=False,
- default=1,
- help="Tensor model parallel size. Default is 1.",
- )
- parser.add_argument(
- "--accumulate-grad-batches",
- type=int,
- required=False,
- default=1,
- help="Gradient accumulation steps. Global batch size is inferred from this.",
- )
- parser.add_argument(
- "--save-last-checkpoint",
- action="store_true",
- default=True,
- help="Save the last checkpoint.",
- )
- parser.add_argument(
- "--metric-to-monitor-for-checkpoints",
- type=str,
- required=False,
- default="val_loss",
- help="The metric to monitor for checkpointing.",
- )
- parser.add_argument(
- "--save-top-k",
- type=int,
- required=False,
- default=2,
- help="Save the top k checkpoints.",
- )
- parser.add_argument(
- "--restore-from-checkpoint-path",
- type=Path,
- required=False,
- default=None,
- help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.",
- )
- parser.add_argument(
- "--nsys-profiling",
- action="store_true",
- default=False,
- help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling output you must run the whole program with `nsys`. For example: "
- " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop [regular python command here]`",
- )
- # start, end, rank
- parser.add_argument(
- "--nsys-start-step",
- type=int,
- required=False,
- default=0,
- help="Start nsys profiling after this step.",
- )
- parser.add_argument(
- "--nsys-end-step",
- type=int,
- required=False,
- help="End nsys profiling after this step.",
- )
- # rank as list of integers
- parser.add_argument(
- "--nsys-ranks",
- type=int,
- nargs="+",
- required=False,
- default=[0],
- help="Enable nsys profiling for these ranks.",
- )
- # DDP config
- parser.add_argument(
- "--no-overlap-grad-reduce",
- action="store_true",
- default=False,
- )
- parser.add_argument(
- "--no-overlap-param-gather",
- action="store_true",
- default=False,
- )
- parser.add_argument(
- "--no-average-in-collective",
- action="store_true",
- default=False,
- )
- parser.add_argument(
- "--grad-reduce-in-fp32",
- action="store_true",
- default=False,
- )
-
- config_class_options: Dict[str, Type[BioBertConfig]] = SUPPORTED_CONFIGS
-
- def config_class_type(desc: str) -> Type[BioBertConfig]:
- try:
- return config_class_options[desc]
- except KeyError:
- raise argparse.ArgumentTypeError(
- f"Do not recognize key {desc}, valid options are: {config_class_options.keys()}"
- )
-
- parser.add_argument(
- "--config-class",
- type=config_class_type,
- default=ESM2FineTuneSeqConfig,
- help="Model configs link model classes with losses, and handle model initialization (including from a prior "
- "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model "
- "class and loss, then implement and provide an alternative config class that points to a variant of that model "
- "and alternative loss. In the future this script should also provide similar support for picking different data "
- f"modules for fine-tuning with different data types. Choices: {config_class_options.keys()}",
- )
-
- dataset_class_options: Dict[str, Type[InMemoryProteinDataset]] = SUPPORTED_DATASETS
-
- def dataset_class_type(desc: str) -> Type[InMemoryProteinDataset]:
- try:
- return dataset_class_options[desc]
- except KeyError:
- raise argparse.ArgumentTypeError(
- f"Do not recognize key {desc}, valid options are: {dataset_class_options.keys()}"
- )
-
- parser.add_argument(
- "--dataset-class",
- type=dataset_class_type,
- default=InMemorySingleValueDataset,
- help=f"Dataset class name for finetuning. Choices: {config_class_options.keys()}",
- )
- return parser
-
-
-if __name__ == "__main__":
- finetune_esm2_entrypoint()
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py
index 9531165c13..bdfaa4fe6b 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py
@@ -23,8 +23,7 @@
from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
from bionemo.esm2.api import ESM2Config
from bionemo.esm2.data.tokenizer import get_tokenizer
-from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
-from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset
+from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule, InMemoryCSVDataset
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
@@ -111,7 +110,7 @@ def infer_model(
plugins=nl.MegatronMixedPrecision(precision=precision),
)
- dataset = InMemoryProteinDataset.from_csv(data_path)
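+    # The inference CSV only needs a 'sequences' column; labels are optional for prediction.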
+ dataset = InMemoryCSVDataset(data_path=data_path)
datamodule = ESM2FineTuneDataModule(
predict_dataset=dataset,
micro_batch_size=micro_batch_size,
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py
index fc896b54a2..2e0c7a0f7f 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py
@@ -87,14 +87,3 @@ def dummy_data_single_value_regression_ft(dummy_data_per_token_classification_ft
"""
data = [(seq, len(seq) / 100.0) for seq, _ in dummy_data_per_token_classification_ft]
return data
-
-
-@pytest.fixture
-def dummy_protein_sequences(dummy_data_per_token_classification_ft):
- """Fixture providing dummy data for per-token classification fine-tuning.
-
- Returns:
- list: A list of dummy data for per-token classification fine-tuning.
- """
- data = [seq for seq, _ in dummy_data_per_token_classification_ft]
- return data
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py
deleted file mode 100644
index c1ccfb5284..0000000000
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import pandas as pd
-import pytest
-from torch.utils.data import DataLoader
-
-from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
-from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset
-
-
-@pytest.fixture
-def dummy_protein_csv(tmp_path, dummy_protein_sequences):
- """Create a mock protein dataset."""
- csv_file = tmp_path / "protein_dataset.csv"
- # Create a DataFrame
- df = pd.DataFrame(dummy_protein_sequences, columns=["sequences"])
-
- # Save the DataFrame to a CSV file
- df.to_csv(csv_file, index=False)
- return csv_file
-
-
-@pytest.fixture
-def dataset(dummy_protein_csv):
- return InMemoryProteinDataset.from_csv(dummy_protein_csv)
-
-
-@pytest.fixture
-def data_module(dataset):
- return ESM2FineTuneDataModule(predict_dataset=dataset)
-
-
-def test_in_memory_csv_dataset(dataset):
- assert len(dataset) > 0
- sample = dataset[0]
- assert isinstance(sample, dict)
- assert "text" in sample
- assert "labels" in sample
-
-
-def test_esm2_fine_tune_data_module_init(data_module):
- assert data_module.train_dataset is None
- assert data_module.valid_dataset is None
- assert data_module.predict_dataset is not None
-
-
-def test_esm2_fine_tune_data_module_predict_dataloader(data_module):
- predict_dataloader = data_module.predict_dataloader()
- assert isinstance(predict_dataloader, DataLoader)
- batch = next(iter(predict_dataloader))
- assert isinstance(batch, dict)
- assert "text" in batch
-
-
-def test_esm2_fine_tune_data_module_setup(data_module):
- with pytest.raises(RuntimeError):
- data_module.setup("fit")
-
-
-def test_esm2_fine_tune_data_module_train_dataloader(data_module):
- with pytest.raises(AttributeError):
- data_module.train_dataloader()
-
-
-def test_esm2_fine_tune_data_module_val_dataloader(data_module):
- with pytest.raises(AttributeError):
- data_module.val_dataloader()
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py
deleted file mode 100644
index afcd53feab..0000000000
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import pandas as pd
-import pytest
-import torch
-from torch import Tensor
-
-from bionemo.esm2.model.finetune.dataset import (
- InMemoryPerTokenValueDataset,
- InMemoryProteinDataset,
- InMemorySingleValueDataset,
-)
-from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX
-from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer
-
-
-def data_to_csv(data, tmp_path, with_label=True):
- """Create a mock protein dataset."""
- csv_file = tmp_path / "protein_dataset.csv"
- # Create a DataFrame
- df = pd.DataFrame(data, columns=["sequences", "labels"] if with_label else ["sequences"])
-
- # Save the DataFrame to a CSV file
- df.to_csv(csv_file, index=False)
- return csv_file
-
-
-@pytest.fixture
-def dataset_no_labels(dummy_protein_sequences, tmp_path):
- csv_path = data_to_csv(dummy_protein_sequences, tmp_path, with_label=False)
- return InMemoryProteinDataset.from_csv(csv_path)
-
-
-@pytest.fixture
-def dataset_regression_labels(dummy_data_single_value_regression_ft, tmp_path):
- csv_path = data_to_csv(dummy_data_single_value_regression_ft, tmp_path, with_label=True)
- return InMemorySingleValueDataset.from_csv(csv_path)
-
-
-@pytest.fixture
-def dataset_per_token_classification_labels(dummy_data_per_token_classification_ft, tmp_path):
- csv_path = data_to_csv(dummy_data_per_token_classification_ft, tmp_path, with_label=True)
- return InMemoryPerTokenValueDataset.from_csv(csv_path)
-
-
-def test_in_memory_protein_dataset_length_no_labels(dataset_no_labels, dummy_protein_sequences):
- assert len(dataset_no_labels) == len(dummy_protein_sequences)
-
-
-def test_in_memory_protein_dataset_length_with_regression_labels(
- dataset_regression_labels, dummy_data_single_value_regression_ft
-):
- assert len(dataset_regression_labels) == len(dummy_data_single_value_regression_ft)
-
-
-def test_in_memory_protein_dataset_length_with_class_labels(
- dataset_per_token_classification_labels, dummy_data_per_token_classification_ft
-):
- assert len(dataset_per_token_classification_labels) == len(dummy_data_per_token_classification_ft)
-
-
-def test_in_memory_protein_dataset_getitem_no_labels(dataset_no_labels):
- sample = dataset_no_labels[0]
- assert isinstance(sample, dict)
- assert "text" in sample
- assert "labels" in sample
- assert isinstance(sample["text"], Tensor)
- assert isinstance(sample["labels"], Tensor)
-
-
-def test_in_memory_protein_dataset_getitem_with_regression_labels(dataset_regression_labels):
- assert isinstance(dataset_regression_labels, InMemoryProteinDataset)
- sample = dataset_regression_labels[0]
- assert isinstance(sample, dict)
- assert "text" in sample
- assert "labels" in sample
- assert isinstance(sample["text"], Tensor)
- assert isinstance(sample["labels"], Tensor)
- assert sample["labels"].dtype == torch.float
-
-
-def test_in_memory_protein_dataset_getitem_with_class_labels(dataset_per_token_classification_labels):
- assert isinstance(dataset_per_token_classification_labels, InMemoryProteinDataset)
- assert isinstance(dataset_per_token_classification_labels.label_tokenizer, Label2IDTokenizer)
- assert dataset_per_token_classification_labels.label_cls_eos_id == MLM_LOSS_IGNORE_INDEX
-
- sample = dataset_per_token_classification_labels[0]
- assert isinstance(sample, dict)
- assert "text" in sample
- assert "labels" in sample
- assert isinstance(sample["text"], Tensor)
- assert isinstance(sample["labels"], Tensor)
- assert sample["labels"].dtype == torch.int64
-
-
-def test_in_memory_protein_dataset_tokenization(dataset_no_labels):
- sample = dataset_no_labels[0]
- tokenized_sequence = sample["text"]
- assert isinstance(tokenized_sequence, Tensor)
- assert tokenized_sequence.ndim == 1 # Ensure it's flattened.
-
-
-def test_transform_classification_label(
- dataset_per_token_classification_labels, dummy_data_per_token_classification_ft
-):
- pre_transform = dummy_data_per_token_classification_ft[0][1]
- label_ids = torch.tensor(dataset_per_token_classification_labels.label_tokenizer.text_to_ids(pre_transform))
- cls_eos = torch.tensor([dataset_per_token_classification_labels.label_cls_eos_id])
- post_transform = torch.cat((cls_eos, label_ids, cls_eos))
-
- assert torch.equal(dataset_per_token_classification_labels.transform_label(pre_transform), post_transform)
-
-
-def test_transform_regression_label(dataset_regression_labels):
- """Ensure labels are transformed correctly."""
- transformed_label = dataset_regression_labels.transform_label(1.0)
- assert isinstance(transformed_label, Tensor)
- assert transformed_label.dtype == torch.float
-
-
-def test_in_memory_protein_dataset_no_labels_fallback(dataset_no_labels):
- """Ensure the dataset works even when labels are missing."""
- sample = dataset_no_labels[0]
- assert isinstance(sample, dict)
- assert "labels" in sample
- assert isinstance(sample["labels"], Tensor)
-
-
-def test_in_memory_protein_dataset_invalid_index(dataset_no_labels):
- """Test if out-of-range index raises an error."""
- with pytest.raises(KeyError):
- _ = dataset_no_labels[100]
-
-
-def test_in_memory_protein_dataset_missing_sequences_column(tmp_path):
- """Test behavior when the CSV file is empty."""
- csv_file = tmp_path / "invalid.csv"
- pd.DataFrame({"wrong_column": ["MKTFFS"]}).to_csv(csv_file, index=False)
- with pytest.raises(KeyError):
- _ = InMemoryProteinDataset.from_csv(csv_file)
-
-
-def test_in_memory_protein_dataset_special_tokens_masking(dataset_no_labels):
- """Ensure loss mask correctly handles special tokens."""
- sample = dataset_no_labels[0]
- assert "loss_mask" in sample
- assert isinstance(sample["loss_mask"], Tensor)
- assert sample["loss_mask"].dtype == torch.bool
-
-
-def test_in_memory_protein_dataset_non_existent_file():
- """Ensure proper error handling for missing files."""
- with pytest.raises(FileNotFoundError):
- InMemoryProteinDataset.from_csv("non_existent_file.csv")
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py
new file mode 100644
index 0000000000..9c49d70c42
--- /dev/null
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py
@@ -0,0 +1,143 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+from nemo.lightning import io
+
+from bionemo.core.data.load import load
+from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
+from bionemo.esm2.model.finetune.finetune_regressor import (
+ ESM2FineTuneSeqConfig,
+ InMemorySingleValueDataset,
+)
+from bionemo.esm2.model.finetune.finetune_token_classifier import (
+ ESM2FineTuneTokenConfig,
+ InMemoryPerTokenValueDataset,
+)
+from bionemo.esm2.model.finetune.peft import ESM2LoRA
+from bionemo.esm2.model.finetune.train import train_model
+from bionemo.testing import megatron_parallel_state_utils
+from bionemo.testing.callbacks import MetricTracker
+
+
+# Download an 8M internally pre-trained ESM2 model checkpoint
+pretrain_ckpt_path = load("esm2/nv_8m:2.0")
+
+
+@pytest.mark.needs_gpu
+@pytest.mark.parametrize("with_peft", [True, False])
+def test_esm2_finetune_token_classifier(
+ tmp_path,
+ tokenizer,
+ dummy_data_per_token_classification_ft,
+ with_peft: bool,
+ n_steps_train: int = 50,
+ seed: int = 42,
+):
+ if with_peft:
+ pytest.xfail("FIXME PEFT fine-tuning not supported with fusions active")
+
+ with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
+ if with_peft:
+ peft = ESM2LoRA()
+ else:
+ peft = None
+ esm2_finetune_config = ESM2FineTuneTokenConfig(initial_ckpt_path=str(pretrain_ckpt_path))
+ dataset = InMemoryPerTokenValueDataset(dummy_data_per_token_classification_ft, seed=seed)
+ finetune_data_module = ESM2FineTuneDataModule(dataset, dataset)
+ simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
+ experiment_name="finetune_new_head",
+ experiment_dir=tmp_path / "finetune_new_head", # new checkpoint will land in a subdir of this
+ config=esm2_finetune_config, # fine-tune config initialized from the pre-trained checkpoint
+ data_module=finetune_data_module,
+ n_steps_train=n_steps_train,
+ metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
+ tokenizer=tokenizer,
+ peft=peft,
+ _use_rich_model_summary=False,
+ )
+
+ weights_ckpt = simple_ft_checkpoint / "weights"
+ assert weights_ckpt.exists()
+ assert weights_ckpt.is_dir()
+ assert io.is_distributed_ckpt(weights_ckpt)
+ assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
+
+ if with_peft:
+ assert trainer.model.model_transform is not None
+ model = trainer.model[0].module.module.module
+ assert all(not p.requires_grad for p in model.embedding.parameters())
+ assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name)
+ assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name)
+ assert all(p.requires_grad for p in model.classification_head.parameters())
+ else:
+ encoder_requires_grad = [
+ p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name
+ ]
+ assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
+
+
+@pytest.mark.needs_gpu
+@pytest.mark.parametrize("with_peft", [True, False])
+def test_esm2_finetune_regressor(
+ tmp_path,
+ tokenizer,
+ dummy_data_single_value_regression_ft,
+ with_peft: bool,
+ n_steps_train: int = 50,
+ seed: int = 42,
+):
+ if with_peft:
+ pytest.xfail("FIXME PEFT fine-tuning not supported")
+
+ with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
+ if with_peft:
+ peft = ESM2LoRA()
+ else:
+ peft = None
+ esm2_regression_finetune_config = ESM2FineTuneSeqConfig(initial_ckpt_path=str(pretrain_ckpt_path))
+ dataset = InMemorySingleValueDataset(dummy_data_single_value_regression_ft, seed=seed)
+ finetune_data_module = ESM2FineTuneDataModule(dataset, dataset)
+ simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
+ experiment_name="finetune_new_head_regression",
+ experiment_dir=tmp_path / "finetune_new_head_regression", # new checkpoint will land in a subdir of this
+ config=esm2_regression_finetune_config, # fine-tune config initialized from the pre-trained checkpoint
+ data_module=finetune_data_module,
+ n_steps_train=n_steps_train,
+ metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
+ tokenizer=tokenizer,
+ peft=peft,
+ _use_rich_model_summary=False,
+ )
+
+ weights_ckpt = simple_ft_checkpoint / "weights"
+ assert weights_ckpt.exists()
+ assert weights_ckpt.is_dir()
+ assert io.is_distributed_ckpt(weights_ckpt)
+ assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
+
+ if with_peft:
+ assert trainer.model.model_transform is not None
+ model = trainer.model[0].module.module.module
+ assert all(not p.requires_grad for p in model.embedding.parameters())
+ assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name)
+ assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name)
+ assert all(p.requires_grad for p in model.regression_head.parameters())
+ else:
+ encoder_requires_grad = [
+ p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name
+ ]
+ assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_regressor.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_regressor.py
deleted file mode 100644
index f9ad25851c..0000000000
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_regressor.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import pytest
-
-from bionemo.core.data.load import load
-from bionemo.esm2.data import tokenizer
-from bionemo.esm2.model.finetune.finetune_regressor import (
- ESM2FineTuneSeqConfig,
- ESM2FineTuneSeqModel,
- MegatronMLPHead,
-)
-from bionemo.testing import megatron_parallel_state_utils
-
-
-# To download an 8M internally pre-trained ESM2 model
-pretrain_ckpt_path = load("esm2/nv_8m:2.0")
-
-
-@pytest.fixture
-def config():
- return ESM2FineTuneSeqConfig(encoder_frozen=True, ft_dropout=0.50, initial_ckpt_path=str(pretrain_ckpt_path))
-
-
-@pytest.fixture
-def finetune_seq_model(config):
- with megatron_parallel_state_utils.distributed_model_parallel_state():
- model = config.configure_model(tokenizer.get_tokenizer())
- yield model
-
-
-def test_ft_config(config):
- assert config.initial_ckpt_skip_keys_with_these_prefixes == ["regression_head"]
- assert config.encoder_frozen
- assert config.ft_dropout == 0.50
-
-
-def test_ft_model_initialized(finetune_seq_model):
- assert isinstance(finetune_seq_model, ESM2FineTuneSeqModel)
- assert isinstance(finetune_seq_model.regression_head, MegatronMLPHead)
- assert finetune_seq_model.post_process
- assert not finetune_seq_model.include_embeddings_finetuning
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_token_classifier.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_token_classifier.py
deleted file mode 100644
index fe4043831e..0000000000
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_token_classifier.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import pytest
-
-from bionemo.core.data.load import load
-from bionemo.esm2.data import tokenizer
-from bionemo.esm2.model.finetune.finetune_token_classifier import (
- ESM2FineTuneTokenConfig,
- ESM2FineTuneTokenModel,
- MegatronConvNetHead,
-)
-from bionemo.testing import megatron_parallel_state_utils
-
-
-# To download an 8M internally pre-trained ESM2 model
-pretrain_ckpt_path = load("esm2/nv_8m:2.0")
-
-
-@pytest.fixture
-def config():
- return ESM2FineTuneTokenConfig(encoder_frozen=True, cnn_dropout=0.1, cnn_hidden_dim=32, cnn_num_classes=5)
-
-
-@pytest.fixture
-def finetune_token_model(config):
- with megatron_parallel_state_utils.distributed_model_parallel_state():
- model = config.configure_model(tokenizer.get_tokenizer())
- yield model
-
-
-def test_ft_config(config):
- assert config.initial_ckpt_skip_keys_with_these_prefixes == ["classification_head"]
- assert config.encoder_frozen
- assert config.cnn_dropout == 0.1
- assert config.cnn_hidden_dim == 32
- assert config.cnn_num_classes == 5
-
-
-def test_ft_model_initialized(finetune_token_model):
- assert isinstance(finetune_token_model, ESM2FineTuneTokenModel)
- assert isinstance(finetune_token_model.classification_head, MegatronConvNetHead)
- assert finetune_token_model.post_process
- assert not finetune_token_model.include_hiddens_finetuning
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py
deleted file mode 100644
index bc929a1abb..0000000000
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py
+++ /dev/null
@@ -1,315 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from pathlib import Path
-from unittest.mock import patch
-
-import pandas as pd
-import pytest
-from nemo.lightning import io
-
-from bionemo.core.data.load import load
-from bionemo.esm2.model.finetune.dataset import InMemoryPerTokenValueDataset, InMemorySingleValueDataset
-from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
-from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
-from bionemo.esm2.scripts.finetune_esm2 import finetune_esm2_entrypoint, get_parser, train_model
-from bionemo.testing import megatron_parallel_state_utils
-from bionemo.testing.callbacks import MetricTracker
-
-
-# To download an 8M internally pre-trained ESM2 model
-pretrain_ckpt_path = load("esm2/nv_8m:2.0")
-
-
-def data_to_csv(data, tmp_path):
- """Create a mock protein dataset."""
- csv_file = tmp_path / "protein_dataset.csv"
- # Create a DataFrame
- df = pd.DataFrame(data, columns=["sequences", "labels"])
-
- # Save the DataFrame to a CSV file
- df.to_csv(csv_file, index=False)
- return csv_file
-
-
-@pytest.mark.needs_gpu
-def test_esm2_finetune_token_classifier(
- tmp_path,
- dummy_data_per_token_classification_ft,
- n_steps_train: int = 50,
- seed: int = 42,
-):
- with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
- simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
- train_data_path=data_to_csv(dummy_data_per_token_classification_ft, tmp_path),
- valid_data_path=data_to_csv(dummy_data_per_token_classification_ft, tmp_path),
- experiment_name="finetune_new_head_token_classification",
- restore_from_checkpoint_path=str(pretrain_ckpt_path),
- num_steps=n_steps_train,
- num_nodes=1,
- devices=1,
- min_seq_length=None,
- max_seq_length=1024,
- result_dir=tmp_path / "finetune",
- limit_val_batches=2,
- val_check_interval=n_steps_train // 2,
- log_every_n_steps=n_steps_train // 2,
- num_dataset_workers=10,
- lr=1e-5,
- micro_batch_size=4,
- accumulate_grad_batches=1,
- resume_if_exists=False,
- precision="bf16-mixed",
- dataset_class=InMemoryPerTokenValueDataset,
- config_class=ESM2FineTuneTokenConfig,
- metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
- )
-
- weights_ckpt = simple_ft_checkpoint / "weights"
- assert weights_ckpt.exists()
- assert weights_ckpt.is_dir()
- assert io.is_distributed_ckpt(weights_ckpt)
- assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
-
- encoder_requires_grad = [
- p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name
- ]
- assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
-
-
-@pytest.mark.needs_gpu
-def test_esm2_finetune_regressor(
- tmp_path,
- dummy_data_single_value_regression_ft,
- n_steps_train: int = 50,
- seed: int = 42,
-):
- with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
- simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
- train_data_path=data_to_csv(dummy_data_single_value_regression_ft, tmp_path),
- valid_data_path=data_to_csv(dummy_data_single_value_regression_ft, tmp_path),
- experiment_name="finetune_new_head_regression",
- restore_from_checkpoint_path=str(pretrain_ckpt_path),
- num_steps=n_steps_train,
- num_nodes=1,
- devices=1,
- min_seq_length=None,
- max_seq_length=1024,
- result_dir=tmp_path / "finetune",
- limit_val_batches=2,
- val_check_interval=n_steps_train // 2,
- log_every_n_steps=n_steps_train // 2,
- num_dataset_workers=10,
- lr=1e-5,
- micro_batch_size=4,
- accumulate_grad_batches=1,
- resume_if_exists=False,
- precision="bf16-mixed",
- dataset_class=InMemorySingleValueDataset,
- config_class=ESM2FineTuneSeqConfig,
- metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
- )
-
- weights_ckpt = simple_ft_checkpoint / "weights"
- assert weights_ckpt.exists()
- assert weights_ckpt.is_dir()
- assert io.is_distributed_ckpt(weights_ckpt)
- assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
-
- encoder_requires_grad = [
- p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name
- ]
- assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
-
-
-@pytest.fixture
-def mock_train_model():
- with patch("bionemo.esm2.scripts.finetune_esm2.train_model") as mock_train:
- yield mock_train
-
-
-@pytest.fixture
-def mock_parser_args():
- """Fixture to create mock arguments for the parser."""
- return [
- "--train-data-path",
- str(Path("train.csv")),
- "--valid-data-path",
- str(Path("valid.csv")),
- "--num-gpus",
- "1",
- "--num-nodes",
- "1",
- "--min-seq-length",
- "512",
- "--max-seq-length",
- "1024",
- "--result-dir",
- str(Path("./results")),
- "--lr",
- "0.001",
- ]
-
-
-def test_finetune_esm2_entrypoint(mock_train_model, mock_parser_args):
- """Test the finetune_esm2_entrypoint function with mocked arguments."""
- with patch("sys.argv", ["finetune_esm2_entrypoint.py"] + mock_parser_args):
- finetune_esm2_entrypoint()
-
- # Check if train_model was called once
- mock_train_model.assert_called_once()
-
- # Check if the arguments were passed correctly
- called_kwargs = mock_train_model.call_args.kwargs
- assert called_kwargs["train_data_path"] == Path("train.csv")
- assert called_kwargs["valid_data_path"] == Path("valid.csv")
- assert called_kwargs["devices"] == 1
- assert called_kwargs["num_nodes"] == 1
- assert called_kwargs["min_seq_length"] == 512
- assert called_kwargs["max_seq_length"] == 1024
- assert called_kwargs["lr"] == 0.001
- assert called_kwargs["result_dir"] == Path("./results")
-
-
-def test_get_parser():
- """Test the argument parser with all possible arguments."""
- parser = get_parser()
- args = parser.parse_args(
- [
- "--train-data-path",
- "train.csv",
- "--valid-data-path",
- "valid.csv",
- "--precision",
- "bf16-mixed",
- "--lr",
- "0.001",
- "--create-tensorboard-logger",
- "--resume-if-exists",
- "--result-dir",
- "./results",
- "--experiment-name",
- "esm2_experiment",
- "--wandb-entity",
- "my_team",
- "--wandb-project",
- "ft_project",
- "--wandb-tags",
- "tag1",
- "tag2",
- "--wandb-group",
- "group1",
- "--wandb-id",
- "1234",
- "--wandb-anonymous",
- "--wandb-log-model",
- "--wandb-offline",
- "--num-gpus",
- "2",
- "--num-nodes",
- "1",
- "--num-steps",
- "1000",
- "--num-dataset-workers",
- "4",
- "--val-check-interval",
- "500",
- "--log-every-n-steps",
- "100",
- "--min-seq-length",
- "512",
- "--max-seq-length",
- "1024",
- "--limit-val-batches",
- "2",
- "--micro-batch-size",
- "32",
- "--pipeline-model-parallel-size",
- "2",
- "--tensor-model-parallel-size",
- "2",
- "--accumulate-grad-batches",
- "2",
- "--save-last-checkpoint",
- "--metric-to-monitor-for-checkpoints",
- "val_loss",
- "--save-top-k",
- "5",
- "--restore-from-checkpoint-path",
- "./checkpoint",
- "--nsys-profiling",
- "--nsys-start-step",
- "10",
- "--nsys-end-step",
- "50",
- "--nsys-ranks",
- "0",
- "1",
- "--no-overlap-grad-reduce",
- "--no-overlap-param-gather",
- "--no-average-in-collective",
- "--grad-reduce-in-fp32",
- "--dataset-class",
- "InMemoryPerTokenValueDataset",
- "--config-class",
- "ESM2FineTuneTokenConfig",
- ]
- )
-
- # Assertions for all arguments
- assert args.train_data_path == Path("train.csv")
- assert args.valid_data_path == Path("valid.csv")
- assert args.precision == "bf16-mixed"
- assert args.lr == 0.001
- assert args.create_tensorboard_logger is True
- assert args.resume_if_exists is True
- assert args.result_dir == Path("./results")
- assert args.experiment_name == "esm2_experiment"
- assert args.wandb_entity == "my_team"
- assert args.wandb_project == "ft_project"
- assert args.wandb_tags == ["tag1", "tag2"]
- assert args.wandb_group == "group1"
- assert args.wandb_id == "1234"
- assert args.wandb_anonymous is True
- assert args.wandb_log_model is True
- assert args.wandb_offline is True
- assert args.num_gpus == 2
- assert args.num_nodes == 1
- assert args.num_steps == 1000
- assert args.num_dataset_workers == 4
- assert args.val_check_interval == 500
- assert args.log_every_n_steps == 100
- assert args.min_seq_length == 512
- assert args.max_seq_length == 1024
- assert args.limit_val_batches == 2
- assert args.micro_batch_size == 32
- assert args.pipeline_model_parallel_size == 2
- assert args.tensor_model_parallel_size == 2
- assert args.accumulate_grad_batches == 2
- assert args.save_last_checkpoint is True
- assert args.metric_to_monitor_for_checkpoints == "val_loss"
- assert args.save_top_k == 5
- assert args.restore_from_checkpoint_path == Path("./checkpoint")
- assert args.nsys_profiling is True
- assert args.nsys_start_step == 10
- assert args.nsys_end_step == 50
- assert args.nsys_ranks == [0, 1]
- assert args.no_overlap_grad_reduce is True
- assert args.no_overlap_param_gather is True
- assert args.no_average_in_collective is True
- assert args.grad_reduce_in_fp32 is True
- assert args.dataset_class == InMemoryPerTokenValueDataset
- assert args.config_class == ESM2FineTuneTokenConfig
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py
index 9575fafa9e..e601ce18ed 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py
@@ -19,17 +19,45 @@
import pandas as pd
import pytest
import torch
+from torch.utils.data import DataLoader
from bionemo.core.data.load import load
from bionemo.core.utils.dtypes import get_autocast_dtype
from bionemo.esm2.api import ESM2Config
from bionemo.esm2.data.tokenizer import get_tokenizer
+from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule, InMemoryCSVDataset
from bionemo.esm2.scripts.infer_esm2 import infer_model
from bionemo.llm.data import collate
from bionemo.llm.lightning import batch_collator
from bionemo.llm.utils.callbacks import IntervalT
+# Helper that returns True when the first GPU's total memory is below threshold_gb (in GB)
+def check_gpu_memory(threshold_gb):
+ if torch.cuda.is_available():
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Memory in GB
+ return gpu_memory < threshold_gb
+ return False
+
+
+@pytest.fixture
+def dummy_protein_sequences():
+ """Create a list of artificial protein sequences"""
+ artificial_sequence_data = [
+ "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+ "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
+ "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
+ "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
+ "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
+ "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
+ "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
+ ]
+ return artificial_sequence_data
+
+
@pytest.fixture
def dummy_protein_csv(tmp_path, dummy_protein_sequences):
"""Create a mock protein dataset."""
@@ -42,6 +70,16 @@ def dummy_protein_csv(tmp_path, dummy_protein_sequences):
return csv_file
+@pytest.fixture
+def dataset(dummy_protein_csv):
+ return InMemoryCSVDataset(dummy_protein_csv)
+
+
+@pytest.fixture
+def data_module(dataset):
+ return ESM2FineTuneDataModule(predict_dataset=dataset)
+
+
@pytest.fixture
def padded_tokenized_sequences(dummy_protein_sequences):
tokenizer = get_tokenizer()
@@ -53,6 +91,49 @@ def padded_tokenized_sequences(dummy_protein_sequences):
return collated_batch["text"]
+def test_in_memory_csv_dataset(dataset):
+ assert len(dataset) > 0
+ sample = dataset[0]
+ assert isinstance(sample, dict)
+ assert "text" in sample
+ assert "labels" in sample
+
+
+def test_in_memory_csv_dataset_load_data(dataset, dummy_protein_csv):
+ sequences, labels = dataset.load_data(dummy_protein_csv)
+ assert isinstance(sequences, list)
+ assert isinstance(labels, list)
+
+
+def test_esm2_fine_tune_data_module_init(data_module):
+ assert data_module.train_dataset is None
+ assert data_module.valid_dataset is None
+ assert data_module.predict_dataset is not None
+
+
+def test_esm2_fine_tune_data_module_predict_dataloader(data_module):
+ predict_dataloader = data_module.predict_dataloader()
+ assert isinstance(predict_dataloader, DataLoader)
+ batch = next(iter(predict_dataloader))
+ assert isinstance(batch, dict)
+ assert "text" in batch
+
+
+def test_esm2_fine_tune_data_module_setup(data_module):
+ with pytest.raises(RuntimeError):
+ data_module.setup("fit")
+
+
+def test_esm2_fine_tune_data_module_train_dataloader(data_module):
+ with pytest.raises(AttributeError):
+ data_module.train_dataloader()
+
+
+def test_esm2_fine_tune_data_module_val_dataloader(data_module):
+ with pytest.raises(AttributeError):
+ data_module.val_dataloader()
+
+
@pytest.mark.parametrize("precision", ["fp32", "bf16-mixed"])
@pytest.mark.parametrize("prediction_interval", get_args(IntervalT))
def test_infer_runs(