diff --git a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
new file mode 100644
index 0000000000..aa26cd72ff
--- /dev/null
+++ b/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
@@ -0,0 +1,898 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[![ Click here to deploy.](https://uohmivykqgnnbiouffke.supabase.co/storage/v1/object/public/landingpage/brevdeploynavy.svg)](https://console.brev.dev/launchable/deploy?launchableID=env-2rPWpPzzJIxq7SMRJIQehCxBymV)\n",
+ "\n",
+ "
NOTE It takes about 10 minutes to deploy this notebook as a Launchable. As of this writing, we are working on a free tier so a credit card may be required. You can reach out to your NVIDIA rep for credits.
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# ESM-2 Fine-tuning"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "source": [
+ "The [ESM-2](https://www.science.org/doi/abs/10.1126/science.ade2574) model is a transformer-based protein language model that has achieved state-of-the-art results in various protein-related tasks. When fine-tuning ESM2, the task-head plays a crucial role. A task head refers to the additional layer or set of layers added on top of a pre-trained model, like the ESM-2 transformer-based protein language model, to adapt it for a specific downstream task. As a part of transfer learning, a pre-trained model is often utilized to learn generic features from a large-scale dataset. However, these features might not be directly applicable to the specific task at hand. By incorporating a task head, which consists of learnable parameters, the model can adapt and specialize to the target task. The task head serves as a flexible and adaptable component that learns task-specific representations by leveraging the pre-trained features as a foundation. Through fine-tuning, the task head enables the model to learn and extract task-specific patterns, improving performance and addressing the nuances of the downstream task. It acts as a critical bridge between the pre-trained model and the specific task, enabling efficient and effective transfer of knowledge."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " NOTE This tutorial will guide you through the steps for creating a basic regression fine-tuning task for simplicity. The utilities described in this tutorial are available in:\n",
+ "\n",
+ "
bionemo.esm2.model.finetune.finetune_regressor
\n",
+ "\n",
+ "The techniques demonstrated here can be adapted for classification and per-token classification tasks. Utilities needed for secondary structure prediction (token-level classification) are available in \n",
+ "\n",
+ "
bionemo.esm2.model.finetune.finetune_token_classifier
\n",
+ "\n",
+ "In the second part of the tutorial, we will cover loading a pre-trained model, fine-tuning it for both regression and per-token classification tasks, and using the fine-tuned models for inference. For instructions on pre-training the ESM-2 model, please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial.
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Building a Regression Fine-tune Module"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Wwe need to define some key classes to successfully build a fine-tuning module in BioNeMo framework: \n",
+ "\n",
+ "1. **Loss Reduction Class** - To compute the supervised fine-tuning loss.\n",
+ "2. **Fine-Tuned Model Head** - Downstream task head model.\n",
+ "3. **Fine-Tuned Model** - Model that combines ESM-2 with the task head model.\n",
+ "4. **Fine-Tuning Config** - Configures the fine-tuning model and loss to use in the training and inference framework.\n",
+ "5. **Dataset** - Training and inference datasets for ESM2 fine-tuning."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1 - Loss Reduction Class"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A class for calculating the supervised loss of the fine-tune model from targets. We inherit from Megatron Bert Masked Language Model Loss (`BERTMLMLossWithReduction`) and override the `forward()` pass to compute MSE loss of the regression head within a micro-batch. The `reduce()` method is used for computing the average over the micro-batches and is only used for logging.\n",
+ "\n",
+ "```python\n",
+ "class RegressorLossReduction(BERTMLMLossWithReduction):\n",
+ " def forward(\n",
+ " self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]\n",
+ " ) -> Tuple[torch.Tensor, Union[PerTokenLossDict, SameSizeLossDict]]:\n",
+ "\n",
+ " regression_output = forward_out[\"regression_output\"]\n",
+ " targets = batch[\"labels\"].to(dtype=regression_output.dtype) # [b, 1]\n",
+ "\n",
+ " loss = torch.nn.functional.mse_loss(regression_output, targets)\n",
+ " return loss, {\"avg\": loss}\n",
+ "\n",
+ " def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> torch.Tensor:\n",
+ " losses = torch.stack([loss[\"avg\"] for loss in losses_reduced_per_micro_batch])\n",
+ " return losses.mean()\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2 - Fine-Tuned Model Head"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "An MLP class for sequence-level regression. This class inherits `MegatronModule` and uses the fine-tune config (`TransformerConfig`) to configure the regression head for the fine-tuned ESM-2 model.\n",
+ "\n",
+ "```python\n",
+ "class MegatronMLPHead(MegatronModule):\n",
+ " def __init__(self, config: TransformerConfig):\n",
+ " super().__init__(config)\n",
+ " layer_sizes = [config.hidden_size, 256, 1]\n",
+ " self.linear_layers = torch.nn.ModuleList(\n",
+ " [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]\n",
+ " )\n",
+ " self.act = torch.nn.ReLU()\n",
+ " self.dropout = torch.nn.Dropout(p=config.ft_dropout)\n",
+ "\n",
+ " def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:\n",
+ " ...\n",
+ "```"
+ ]
+ },
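+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `forward()` body is elided in the excerpt above. A minimal sketch of what it could look like is shown below, assuming the head simply alternates the dropout, linear layers, and ReLU activation defined in `__init__` (refer to ```bionemo.esm2.model.finetune.finetune_regressor``` for the exact implementation):\n",
+ "\n",
+ "```python\n",
+ "def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
+ "    # hidden_states: [b, hidden_size] pooled sequence representation\n",
+ "    output = hidden_states\n",
+ "    for i, layer in enumerate(self.linear_layers):\n",
+ "        output = self.dropout(output)\n",
+ "        output = layer(output)\n",
+ "        # activation between layers, but not on the final regression output\n",
+ "        if i < len(self.linear_layers) - 1:\n",
+ "            output = self.act(output)\n",
+ "    return output\n",
+ "```"
+ ]
+ },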
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3 - Fine-Tuned Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A fine-tuned ESM-2 model class for token classification tasks. This class inherits from the `ESM2Model` class and adds the custom regression head `MegatronMLPHead` the we created in the previous step. Optionally one can freeze all or parts of the encoder by parsing through the model parameters in the model constructor.\n",
+ "\n",
+ "```python\n",
+ "class ESM2FineTuneSeqModel(ESM2Model):\n",
+ " def __init__(self, config, *args, post_process: bool = True, include_embeddings: bool = False, **kwargs):\n",
+ " super().__init__(config, *args, post_process=post_process, include_embeddings=True, **kwargs)\n",
+ "\n",
+ " # freeze encoder parameters\n",
+ " if config.encoder_frozen:\n",
+ " for _, param in self.named_parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ " if post_process:\n",
+ " self.regression_head = MegatronMLPHead(config)\n",
+ "\n",
+ " def forward(self, *args, **kwargs,):\n",
+ " output = super().forward(*args, **kwargs)\n",
+ " ...\n",
+ " output[\"regression_output\"] = self.regression_head(embeddings)\n",
+ " return output\n",
+ "```"
+ ]
+ },
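+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The elided part of `forward()` passes the sequence-level embeddings produced by the base ESM-2 model (requested via `include_embeddings=True`) through the regression head. A minimal sketch is shown below, assuming the parent class returns a dictionary with an `embeddings` entry and that only the final pipeline stage (where `post_process` is true) owns the head; see ```bionemo.esm2.model.finetune.finetune_regressor``` for the full implementation:\n",
+ "\n",
+ "```python\n",
+ "def forward(self, *args, **kwargs):\n",
+ "    output = super().forward(*args, **kwargs)\n",
+ "    if not self.post_process:\n",
+ "        # intermediate pipeline stages simply pass their output through\n",
+ "        return output\n",
+ "    embeddings = output[\"embeddings\"]  # [b, hidden_size] sequence-level embeddings\n",
+ "    output[\"regression_output\"] = self.regression_head(embeddings)\n",
+ "    return output\n",
+ "```"
+ ]
+ },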
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4 - Fine-Tuning Config"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A `dataclass` that configures the fine-tuned ESM-2 model. In this example `ESM2FineTuneSeqConfig` inherits from `ESM2GenericConfig` and adds custom arguments to setup the fine-tuned model. The `configure_model()` method of this `dataclass` is called within the `Lightning` module to call the model constructor with the `dataclass` arguments.\n",
+ "\n",
+ "The common arguments among different fine-tuning tasks are\n",
+ "\n",
+ "- `model_cls`: The fine-tune model class defined in previous step (`ESM2FineTuneSeqModel`)\n",
+ "- `initial_ckpt_path`: BioNeMo 2.0 ESM-2 pre-trained checkpoint\n",
+ "- `initial_ckpt_skip_keys_with_these_prefixes`: skips keys when loading parameters from a checkpoint. For example, we should not look for `regression_head` in the pre-trained checkpoint.\n",
+ "- `get_loss_reduction_class()`: Implements selection of the appropriate `MegatronLossReduction` class that we defined in the first step of this tutorial.\n",
+ "\n",
+ "```python\n",
+ "\n",
+ "@dataclass\n",
+ "class ESM2FineTuneSeqConfig(\n",
+ " ESM2GenericConfig[ESM2FineTuneSeqModel, RegressorLossReduction], iom.IOMixinWithGettersSetters\n",
+ "):\n",
+ " model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel\n",
+ " # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in\n",
+ " # self.override_parent_fields will be loaded from the checkpoint and override those values here.\n",
+ " initial_ckpt_path: str | None = None\n",
+ " # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint\n",
+ " # that has this new head and want to keep using these weights, please drop this next line or set to []\n",
+ " initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: [\"regression_head\"])\n",
+ "\n",
+ " encoder_frozen: bool = True # freeze encoder parameters\n",
+ " ft_dropout: float = 0.25 # MLP layer dropout\n",
+ "\n",
+ " def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:\n",
+ " return RegressorLossReduction\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 5 - Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "We will use a sample dataset for demonstration purposes. Create a dataset class by extending ```bionemo.esm2.model.finetune.dataset.InMemoryProteinDataset```. The `InMemoryProteinDataset` has a `classmethod` (`from_csv`) that reads data from a CSV file that has `sequences` and optionally `labels` columns. It is important to override the `transform_label()` method that returns a `torch.Tensor` containing the label in correct format. As an example we can use this method to add custom tokenization if `label` is a string."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "The custom dataset class will be appropriate (found in ```bionemo.esm2.model.finetune.dataset.InMemorySingleValueDataset```) as it facilitates predicting on a single value. An excerpt from the class is shown below. This example dataset has a class method `from_csv()` that expects a `data_path` to a CSV file that has `sequences`, and `labels` columns.\n",
+ "\n",
+ "```python\n",
+ "class InMemorySingleValueDataset(InMemoryProteinDataset):\n",
+ " def __init__(\n",
+ " self,\n",
+ " sequences: pd.Series,\n",
+ " labels: pd.Series | None = None,\n",
+ " tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),\n",
+ " seed: int = np.random.SeedSequence().entropy,\n",
+ " ):\n",
+ " super().__init__(sequences, labels, tokenizer, seed)\n",
+ "\n",
+ " def transform_label(self, label: float) -> Tensor:\n",
+ " return torch.tensor([label], dtype=torch.float)\n",
+ "```\n",
+ "\n",
+ "The `transform_label` method allows for custom transformation of raw labels by casting or tokenization and need to be adjusted based on the data. Here we use this method to create a `float` tensor of the regression value."
+ ]
+ },
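+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If your labels are strings rather than floats (for example, class names), a hypothetical variant could map them to integer class indices instead. The dataset name and label vocabulary below are illustrative and not part of the BioNeMo API:\n",
+ "\n",
+ "```python\n",
+ "class InMemoryClassLabelDataset(InMemoryProteinDataset):\n",
+ "    # Illustrative only: a hypothetical vocabulary mapping class names to indices\n",
+ "    label2id = {\"low\": 0, \"medium\": 1, \"high\": 2}\n",
+ "\n",
+ "    def transform_label(self, label: str) -> Tensor:\n",
+ "        # map the string label to an integer class-index tensor\n",
+ "        return torch.tensor(self.label2id[label], dtype=torch.int64)\n",
+ "```"
+ ]
+ },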
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### DataModule"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To coordinate the creation of training, validation and testing datasets from your data, we need to use a `datamodule` class. To do this we can directly use or extend the ```ESM2FineTuneDataModule``` class (located at ```bionemo.esm2.model.finetune.datamodule.ESM2FineTuneDataModule```) which defines helpful abstract methods that use your dataset class.\n",
+ "\n",
+ "```python\n",
+ "dataset = InMemorySingleValueDataset.from_csv(data_path)\n",
+ "data_module = ESM2FineTuneDataModule(\n",
+ " train_dataset=dataset,\n",
+ " valid_dataset=dataset\n",
+ " micro_batch_size=4, # size of a batch to be processed in a device\n",
+ " global_batch_size=8, # size of batch across all devices. Should be multiple of micro_batch_size\n",
+ ")\n",
+ "```"
+ ]
+ },
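+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial we reuse the same dataset for training and validation to keep the example small. With a larger CSV file, you would typically split it first; a sketch of one way to do this (assuming a hypothetical `data.csv` with `sequences` and `labels` columns) is:\n",
+ "\n",
+ "```python\n",
+ "import pandas as pd\n",
+ "\n",
+ "df = pd.read_csv(\"data.csv\")  # hypothetical input file\n",
+ "train_df = df.sample(frac=0.8, random_state=42)\n",
+ "valid_df = df.drop(train_df.index)\n",
+ "\n",
+ "train_df.to_csv(\"train.csv\", index=False)\n",
+ "valid_df.to_csv(\"valid.csv\", index=False)\n",
+ "\n",
+ "train_dataset = InMemorySingleValueDataset.from_csv(\"train.csv\")\n",
+ "valid_dataset = InMemorySingleValueDataset.from_csv(\"valid.csv\")\n",
+ "\n",
+ "data_module = ESM2FineTuneDataModule(\n",
+ "    train_dataset=train_dataset,\n",
+ "    valid_dataset=valid_dataset,\n",
+ "    micro_batch_size=4,\n",
+ "    global_batch_size=8,\n",
+ ")\n",
+ "```"
+ ]
+ },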
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In the next part of this tutorial we will prepare the input needed to run regression and per-token-classification fine-tuning examples."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup and Assumptions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "All commands should be executed inside the BioNeMo docker container, which has all ESM-2 dependencies pre-installed. For more information on how to build or pull the BioNeMo2 container, refer to the [Initialization Guide](../../getting-started/initialization-guide.md)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " NOTE Some of the cells below generate long text output. We're using
%%capture --no-display --no-stderr cell_output
to suppress this output. Comment or delete this line in the cells below to restore full output.
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Import Required Libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture --no-display --no-stderr cell_output\n",
+ "\n",
+ "import os\n",
+ "import shutil\n",
+ "import pandas as pd\n",
+ "\n",
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "warnings.simplefilter('ignore')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Work Directory"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the work directory to store data and results:\n",
+ "\n",
+ " NOTE We set the following to clean up the work directory created by this notebook
cleanup : bool = True
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cleanup : bool = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Directory '/workspace/bionemo2/esm2_finetune_tutorial' created.\n"
+ ]
+ }
+ ],
+ "source": [
+ "work_dir=\"/workspace/bionemo2/esm2_finetune_tutorial\"\n",
+ "\n",
+ "if cleanup and os.path.exists(work_dir):\n",
+ " shutil.rmtree(work_dir)\n",
+ "\n",
+ "if not os.path.exists(work_dir):\n",
+ " os.makedirs(work_dir)\n",
+ " print(f\"Directory '{work_dir}' created.\")\n",
+ "else:\n",
+ " print(f\"Directory '{work_dir}' already exists.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Download Pre-trained Model Checkpoints"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The following code will download the internally pre-trained model, `esm2/nv_8m:2.0`, from the NGC registry. Please refer to [ESM-2 Model Overview](../../../models/ESM-2/index.md) for a list of available checkpoints."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/ubuntu/.cache/bionemo/b4ea4d52eea8a25d2c2838617ff678f0da22d384cee195b0c192686816078dcd-esm2_8m_checkpoint.tar.gz.untar\n"
+ ]
+ }
+ ],
+ "source": [
+ "from bionemo.core.data.load import load\n",
+ "\n",
+ "checkpoint_path = load(\"esm2/nv_8m:2.0\")\n",
+ "print(checkpoint_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The above example is downloading an internally trained 8M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load` as shown above.\n",
+ "\n",
+ "```bash\n",
+ "download_bionemo_data esm2/650m:2.0\n",
+ "```\n",
+ "\n",
+ "which returns the checkpoint path (e.g. `.../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar`)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Fine-tuning"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can take advantage of the ESM2 fine-tuning script in ```bionemo.esm2.scripts.finetune_esm2``` or use the ```finetune_esm2``` executable the fine-tuning process given:\n",
+ "\n",
+ "- Pre-trained checkpoint of ESM2\n",
+ "- Finetune config class name that configures the finetune model and loss reduction\n",
+ "- Path to train and validation CSV data files\n",
+ "- Dataset class name\n",
+ "\n",
+ "To get the full list of arguments to tune a finetuning run use:\n",
+ "```bash\n",
+ "finetune_esm2 --help \n",
+ "```\n",
+ "For a detailed description of training loop and the arguments please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " NOTE\n",
+ "\n",
+ "Due to Megatron limitations, the log produced by the training run iterates on steps/iterations and not epochs. Therefore, `Training epoch` counter stays at value zero while `iteration` and `global_step` increase during the course of training (example in the following).\n",
+ "\n",
+ "
\n",
+ "Training epoch 0, iteration | ... | global_step: | reduced_train_loss: ... | val_loss: ...\n",
+ "
\n",
+ "\n",
+ "to achieve the same epoch-based effect while training, please choose the number of training steps (`num_steps`) so that:\n",
+ "\n",
+ "
\n",
+ "num_steps * global_batch_size = len(dataset) * desired_num_epochs\n",
+ "
"
+ ]
+ },
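+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a concrete example, the runs below use 10 sequences, a micro-batch size of 2, and a single GPU, so the global batch size is also 2 (assuming no gradient accumulation). Ten effective epochs over the dataset therefore correspond to the `--num-steps 50` used in the commands that follow:\n",
+ "\n",
+ "```python\n",
+ "len_dataset = 10          # number of sequences in the sample CSV below\n",
+ "global_batch_size = 2     # micro_batch_size (2) * num_gpus (1), no gradient accumulation\n",
+ "desired_num_epochs = 10\n",
+ "\n",
+ "num_steps = len_dataset * desired_num_epochs // global_batch_size\n",
+ "print(num_steps)  # 50 -> matches --num-steps 50 and consumed_samples=100.0 in the checkpoint name\n",
+ "```"
+ ]
+ },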
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Regression"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For the purposes of this demo, we'll assume dataset consists of small set of protein sequences with a target value of `len(sequence) / 100.0` as their labels:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "artificial_sequence_data = [\n",
+ " \"TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI\",\n",
+ " \"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
+ " \"GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN\",\n",
+ " \"DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
+ " \"KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI\",\n",
+ " \"LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP\",\n",
+ " \"LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
+ " \"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
+ " \"ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP\",\n",
+ " \"SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT\",\n",
+ "]\n",
+ "\n",
+ "data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]\n",
+ "\n",
+ "# Create a DataFrame\n",
+ "df = pd.DataFrame(data, columns=[\"sequences\", \"labels\"])\n",
+ "\n",
+ "# Save the DataFrame to a CSV file\n",
+ "data_path = os.path.join(work_dir, \"regression_data.csv\")\n",
+ "df.to_csv(data_path, index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture --no-display --no-stderr cell_output\n",
+ "\n",
+ "! finetune_esm2 \\\n",
+ " --restore-from-checkpoint-path {checkpoint_path} \\\n",
+ " --train-data-path {data_path} \\\n",
+ " --valid-data-path {data_path} \\\n",
+ " --config-class ESM2FineTuneSeqConfig \\\n",
+ " --dataset-class InMemorySingleValueDataset \\\n",
+ " --experiment-name \"regression\" \\\n",
+ " --num-steps 50 \\\n",
+ " --num-gpus 1 \\\n",
+ " --val-check-interval 10 \\\n",
+ " --log-every-n-steps 10 \\\n",
+ " --lr 5e-3 \\\n",
+ " --result-dir {work_dir} \\\n",
+ " --micro-batch-size 2 \\\n",
+ " --num-gpus 1 \\\n",
+ " --precision \"bf16-mixed\"\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The previous cell executes the finetuning and saves the checkpoints at the end of the run. The checkpoint path is logged at the end of the finetuning log file: \n",
+ "\n",
+ "```\n",
+ "[NeMo I $$$$-$$-$$ 22:04:28 nemo_logging:393] Async checkpoint save for step 50 (/workspace/bionemo2/esm2_finetune_tutorial/regression/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last-v1.ckpt) finalized successfully.\n",
+ "```\n",
+ "\n",
+ "To avoid long text output from the previous cell, the log is captured and stored into the `cell_output` variable. To visualize the log file uncomment and execute the next cell:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# print(cell_output.stdout)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can now use the checkpoint stored in the previous step to run inference. We will drop the `.ckpt` from the checkpoint path and provide that to the `--checkpoint-path` argument of `infer_esm2` executable.\n",
+ "\n",
+ "The input `--data-path` for inference is a CSV file with `sequences` column. It is also required to provide the appropriate `--config-class` name to load the model from the checkpoint. For a detailed description of inference arguments please refer to the [ESM-2 Inference](./inference.ipynb) tutorial."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a DataFrame\n",
+ "df = pd.DataFrame(artificial_sequence_data, columns=[\"sequences\"])\n",
+ "\n",
+ "# Save the DataFrame to a CSV file\n",
+ "data_path = os.path.join(work_dir, \"sequences.csv\")\n",
+ "df.to_csv(data_path, index=False)\n",
+ "\n",
+ "checkpoint_path = f\"{work_dir}/regression/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last\"\n",
+ "results_path = f\"{work_dir}/regression/infer/\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture --no-display --no-stderr cell_output\n",
+ "\n",
+ "! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
+ " --config-class ESM2FineTuneSeqConfig \\\n",
+ " --data-path {data_path} \\\n",
+ " --results-path {results_path} \\\n",
+ " --micro-batch-size 3 \\\n",
+ " --num-gpus 1 \\\n",
+ " --precision \"bf16-mixed\" \\\n",
+ " --include-embeddings \\\n",
+ " --include-input-ids"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The inference results are written into a `.pt` file which can be loaded using PyTorch library:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "input_ids\ttorch.Size([10, 1024])\n",
+ "embeddings\ttorch.Size([10, 320])\n",
+ "regression_output\ttorch.Size([10, 1])\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "results = torch.load(f\"{results_path}/predictions__rank_0.pt\")\n",
+ "\n",
+ "for key, val in results.items():\n",
+ " if val is not None:\n",
+ " print(f'{key}\\t{val.shape}')"
+ ]
+ },
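+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick, purely illustrative sanity check, we can compare the regression output against the synthetic targets we trained on (`len(sequence) / 100.0`). Because the model was trained and validated on this same tiny dataset, close agreement only reflects memorization:\n",
+ "\n",
+ "```python\n",
+ "predictions = results['regression_output'].float().squeeze(-1).cpu()\n",
+ "targets = [len(seq) / 100.0 for seq in artificial_sequence_data]\n",
+ "\n",
+ "for pred, target in zip(predictions.tolist(), targets):\n",
+ "    print(f'predicted: {pred:.3f}  target: {target:.3f}')\n",
+ "```"
+ ]
+ },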
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Toke-level Classification data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For this task we assign secondary structure label to each token in the sequence:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "secondary_structure_labels = [\n",
+ " \"EEEECCCCCHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE\",\n",
+ " \"CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
+ " \"HHHHHCCCCCHHHHHHHHHHHHHHCCCHHHHHHHHHH\",\n",
+ " \"HHHHHHHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
+ " \"CHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE\",\n",
+ " \"HHHHHHHHHHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC\",\n",
+ " \"HHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
+ " \"CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
+ " \"HHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC\",\n",
+ " \"CCCCCCCCCCCCCCCCCCCCCCCCCCEEECCCCEEECHHHHHHHHHCCCCCCCCEEECCCCCC\",\n",
+ "]\n",
+ "\n",
+ "data = [(seq, label) for (seq, label) in zip(artificial_sequence_data, secondary_structure_labels)]\n",
+ "\n",
+ "# Create a DataFrame\n",
+ "df = pd.DataFrame(data, columns=[\"sequences\", \"labels\"])\n",
+ "\n",
+ "# Save the DataFrame to a CSV file\n",
+ "data_path = os.path.join(work_dir, \"token_classification_data.csv\")\n",
+ "df.to_csv(data_path, index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture --no-display --no-stderr cell_output\n",
+ "\n",
+ "! finetune_esm2 \\\n",
+ " --restore-from-checkpoint-path {checkpoint_path} \\\n",
+ " --train-data-path {data_path} \\\n",
+ " --valid-data-path {data_path} \\\n",
+ " --config-class ESM2FineTuneTokenConfig \\\n",
+ " --dataset-class InMemoryPerTokenValueDataset \\\n",
+ " --experiment-name \"token_level_classification\" \\\n",
+ " --num-steps 50 \\\n",
+ " --num-gpus 1 \\\n",
+ " --val-check-interval 10 \\\n",
+ " --log-every-n-steps 10 \\\n",
+ " --lr 5e-3 \\\n",
+ " --result-dir {work_dir} \\\n",
+ " --micro-batch-size 2 \\\n",
+ " --num-gpus 1 \\\n",
+ " --precision \"bf16-mixed\"\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The previous cell executes the finetuning and saves the checkpoints at the end of the run. The checkpoint path is logged at the end of the finetuning log file: \n",
+ "\n",
+ "```\n",
+ "[NeMo I $$$$-$$-$$ 22:16:46 nemo_logging:393] Async checkpoint save for step 50 (/workspace/bionemo2/esm2_finetune_tutorial/token_level_classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last.ckpt) finalized successfully.\n",
+ "```\n",
+ "\n",
+ "To avoid long text output from the previous cell, the log is captured and stored into the `cell_output` variable. To visualize the log file uncomment and execute the next cell:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# print(cell_output.stdout)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can now use the checkpoint stored in the previous step to run inference. We will drop the `.ckpt` from the checkpoint path and provide that to the `--checkpoint-path` argument of `infer_esm2` executable.\n",
+ "\n",
+ "The input `--data-path` for inference is a CSV file with `sequences` column. It is also required to provide the appropriate `--config-class` name to load the model from the checkpoint. For a detailed description of inference arguments please refer to the [ESM-2 Inference](./inference.ipynb) tutorial."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a DataFrame\n",
+ "df = pd.DataFrame(artificial_sequence_data, columns=[\"sequences\"])\n",
+ "\n",
+ "# Save the DataFrame to a CSV file\n",
+ "data_path = os.path.join(work_dir, \"sequences.csv\")\n",
+ "df.to_csv(data_path, index=False)\n",
+ "\n",
+ "checkpoint_path = f\"{work_dir}/token_level_classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last\"\n",
+ "results_path = f\"{work_dir}/token_level_classification/infer/\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture --no-display --no-stderr cell_output\n",
+ "\n",
+ "! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
+ " --config-class ESM2FineTuneTokenConfig \\\n",
+ " --data-path {data_path} \\\n",
+ " --results-path {results_path} \\\n",
+ " --micro-batch-size 3 \\\n",
+ " --num-gpus 1 \\\n",
+ " --precision \"bf16-mixed\" \\\n",
+ " --include-embeddings \\\n",
+ " --include-hiddens \\\n",
+ " --include-input-ids"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The inference results are written into a `.pt` file which can be loaded using PyTorch library:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "hidden_states\ttorch.Size([10, 1024, 320])\n",
+ "input_ids\ttorch.Size([10, 1024])\n",
+ "embeddings\ttorch.Size([10, 320])\n",
+ "classification_output\ttorch.Size([10, 1024, 3])\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "results = torch.load(f\"{results_path}/predictions__rank_0.pt\")\n",
+ "\n",
+ "for key, val in results.items():\n",
+ " if val is not None:\n",
+ " print(f'{key}\\t{val.shape}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can use the label tokenizer to convert the classification output to class names. Note that for demonstration purposes we are using a small dataset of artificial sequences in this example. You may experience over-fitting and observe no change in the validation metrics. This amount of data and the short training run does not result in accurate predictions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from bionemo.esm2.data.tokenizer import get_tokenizer\n",
+ "\n",
+ "\n",
+ "tokenizer = get_tokenizer()\n",
+ "tokens = tokenizer.all_tokens\n",
+ "aa_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']\n",
+ "aa_indices = [i for i, token in enumerate(tokens) if token in aa_tokens]\n",
+ "extra_indices = [i for i, token in enumerate(tokens) if token not in aa_tokens]\n",
+ "\n",
+ "input_ids = results['input_ids'] # b, s\n",
+ "# mask where non-amino acid tokens are True\n",
+ "mask = ~torch.isin(input_ids, torch.tensor(extra_indices))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer\n",
+ "\n",
+ "label_tokenizer = Label2IDTokenizer()\n",
+ "label_tokenizer = label_tokenizer.build_vocab(secondary_structure_labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Predicted Secondary Structures:\n",
+ "HHHHEEEEECCCCCCCCCCCCCCCEEEHHHHHHEEECCCCCCCCCEEEEEEEEEHHH\n",
+ "EEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
+ "CCCCCEEEEECCCCCCCCCCCCCEEEECCCCCCCCCC\n",
+ "CCCCCCCCCCEEECCCCCEEEEEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
+ "ECCCCCCCCCCCCCCCEEEHHHHHHEEECCCCCCCCCEEEEEEEEEHHH\n",
+ "CCCCCCCCCCCCCECCCCCCCCCCCCEEEHHEEEHHHHEEHHHHHEE\n",
+ "CCCCCEEECCCCCEEEEEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
+ "EEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
+ "CCCCCECCCCCCCCCCCCEEEHHEEEHHHHEEHHHHHEE\n",
+ "EEEEEEEEEEEEEEEEEEEEEEEEEEHHHEEEEHHHECCCCCCCCCEEEEEEEEHHHEEEEEE\n"
+ ]
+ }
+ ],
+ "source": [
+ "output_ids = torch.argmax(results[\"classification_output\"], dim=-1)\n",
+ "\n",
+ "print(\"Predicted Secondary Structures:\")\n",
+ "for i in range(output_ids.shape[0]):\n",
+ " ss_ids = output_ids[i][mask[i]]\n",
+ " print(label_tokenizer.ids_to_text(ss_ids.tolist()))"
+ ]
+ },
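+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a rough, illustrative check, we can measure per-token agreement between the predictions and the labels we trained on, assuming `ids_to_text()` returns the label characters without separators (as in the printout above). Since training and validation used these same ten artificial sequences, high agreement here only indicates memorization:\n",
+ "\n",
+ "```python\n",
+ "correct = 0\n",
+ "total = 0\n",
+ "for i, true_labels in enumerate(secondary_structure_labels):\n",
+ "    pred = label_tokenizer.ids_to_text(output_ids[i][mask[i]].tolist())\n",
+ "    correct += sum(p == t for p, t in zip(pred, true_labels))\n",
+ "    total += len(true_labels)\n",
+ "\n",
+ "print(f'Per-token agreement with training labels: {correct / total:.2%}')\n",
+ "```"
+ ]
+ },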
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/docs/user-guide/examples/bionemo-esm2/finetune.md b/docs/docs/user-guide/examples/bionemo-esm2/finetune.md
deleted file mode 100644
index 7968a22397..0000000000
--- a/docs/docs/user-guide/examples/bionemo-esm2/finetune.md
+++ /dev/null
@@ -1,263 +0,0 @@
-# ESM-2 Fine-Tuning
-
-This readme serves as a demo for implementing ESM-2 Fine-tuning module, running a regression example and using the model for inference.
-
-The ESM-2 model is a transformer-based protein language model that has achieved state-of-the-art results in various protein-related tasks. When fine-tuning ESM2, the task head plays a crucial role. A task head refers to the additional layer or set of layers added on top of a pre-trained model, like the ESM-2 transformer-based protein language model, to adapt it for a specific downstream task. As a part of transfer learning, a pre-trained model is often utilized to learn generic features from a large-scale dataset. However, these features might not be directly applicable to the specific task at hand. By incorporating a task head, which consists of learnable parameters, the model can adapt and specialize to the target task. The task head serves as a flexible and adaptable component that learns task-specific representations by leveraging the pre-trained features as a foundation. Through fine-tuning, the task head enables the model to learn and extract task-specific patterns, improving performance and addressing the nuances of the downstream task. It acts as a critical bridge between the pre-trained model and the specific task, enabling efficient and effective transfer of knowledge.
-
-
-# Setup and Assumptions
-
-In this tutorial, we will demonstrate how to create a fine-tune module, train a regression task head, and use the fine-tuned model for inference.
-
-All commands should be executed inside the BioNeMo docker container, which has all ESM-2 dependencies pre-installed. This tutorial assumes that a copy of the BioNeMo framework repo exists on workstation or server and has been mounted inside the container at `/workspace/bionemo2`. (**Note**: This `WORKDIR` may be `/workspaces/bionemo-framework` if you are using the VSCode Dev Container.) For more information on how to build or pull the BioNeMo2 container, refer to the [Access and Startup](../../getting-started/access-startup.md).
-
-To successfully accomplish this we need to define some key classes:
-
-1. Loss Reduction Method - To compute the supervised fine-tuning loss.
-2. Fine-Tuned Model Head - Downstream task head model.
-3. Fine-Tuned Model - Model that combines ESM-2 with the task head model.
-4. Fine-Tuning Config - Configures the fine-tuning model and loss to use in the training and inference framework.
-5. Dataset - Training and inference datasets for ESM2.
-
-## 1 - Loss Reduction Class
-
-A class for calculating the supervised loss of the fine-tune model from targets. We inherit from Megatron Bert Masked Language Model Loss (`BERTMLMLossWithReduction`) and override the `forward()` pass to compute MSE loss of the regression head within a micro-batch. The `reduce()` method is used for computing the average over the micro-batches and is only used for logging.
-
-```python
-class RegressorLossReduction(BERTMLMLossWithReduction):
- def forward(
- self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]
- ) -> Tuple[torch.Tensor, Union[PerTokenLossDict, SameSizeLossDict]]:
-
- targets = batch["labels"] # [b, 1]
- regression_output = forward_out
- loss = torch.nn.functional.mse_loss(regression_output, targets)
- return loss, {"avg": loss}
-
- def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> torch.Tensor:
- losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
- return losses.mean()
-```
-
-## 2 - Fine-Tuned Model Head
-
-An MLP class for sequence-level regression. This class inherits `MegatronModule` and uses the fine-tune config (`TransformerConfig`) to configure the regression head for the fine-tuned ESM-2 model.
-
-```python
-class MegatronMLPHead(MegatronModule):
- def __init__(self, config: TransformerConfig):
- super().__init__(config)
- layer_sizes = [config.hidden_size, 256, 1]
- self.linear_layers = torch.nn.ModuleList(
- [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
- )
- self.act = torch.nn.ReLU()
- self.dropout = torch.nn.Dropout(p=config.ft_dropout)
-
- def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
- ...
-```
-
-## 3 - Fine-Tuned Model
-
-A fine-tuned ESM-2 model class for token classification tasks. This class inherits from the `ESM2Model` class and adds the custom regression head `MegatronMLPHead` the we created in the previous step. Optionally one can freeze all or parts of the encoder by parsing through the model parameters in the model constructor.
-
-```python
-class ESM2FineTuneSeqModel(ESM2Model):
- def __init__(self, config, *args, post_process: bool = True, return_embeddings: bool = False, **kwargs):
- super().__init__(config, *args, post_process=post_process, return_embeddings=True, **kwargs)
-
- # freeze encoder parameters
- if config.encoder_frozen:
- for _, param in self.named_parameters():
- param.requires_grad = False
-
- if post_process:
- self.regression_head = MegatronMLPHead(config)
-
- def forward(self, *args, **kwargs,):
- output = super().forward(*args, **kwargs)
- ...
- regression_output = self.regression_head(embeddings)
- return regression_output
-```
-
-## 4 - Fine-Tuning Config
-
-A `dataclass` that configures the fine-tuned ESM-2 model. In this example `ESM2FineTuneSeqConfig` inherits from `ESM2GenericConfig` and adds custom arguments to setup the fine-tuned model. The `configure_model()` method of this `dataclass` is called within the `Lightning` module to call the model constructor with the `dataclass` arguments.
-
-The common arguments among different fine-tuning tasks are
-
-- `model_cls`: The fine-tune model class (`ESM2FineTuneSeqModel`)
-- `initial_ckpt_path`: BioNeMo 2.0 ESM-2 pre-trained checkpoint
-- `initial_ckpt_skip_keys_with_these_prefixes`: skip keys when loading parameters from a checkpoint. Here we should not look for `regression_head` in the pre-trained checkpoint.
-- `get_loss_reduction_class()`: Implements selection of the appropriate `MegatronLossReduction` class, e.g. `bionemo.esm2.model.finetune.finetune_regressor.RegressorLossReduction`.
-
-```python
-
-@dataclass
-class ESM2FineTuneSeqConfig(ESM2GenericConfig[ESM2FineTuneSeqModel], iom.IOMixinWithGettersSetters):
- model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel
- # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
- # self.override_parent_fields will be loaded from the checkpoint and override those values here.
- initial_ckpt_path: str | None = None
- # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
- # that has this new head and want to keep using these weights, please drop this next line or set to []
- initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
-
- encoder_frozen: bool = True # freeze encoder parameters
- ft_dropout: float = 0.25 # MLP layer dropout
-
- def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:
- return RegressorLossReduction
-```
-
-## 5 - Dataset
-
-We will use a sample dataset for demonstration purposes. Create a dataset class by extending from ```torch.utils.data.Dataset```. For the purposes of this demo, we'll assume dataset consists of small set of protein sequences with a target value of `len(sequence) / 100.0` as their labels.
-
-```python
-data = [
- ("MVLSPADKTNVKAAWGKVGAHAGEYGAEALERH", 0.33),
- ...
-]
-```
-
-Therefore, the custom BioNeMo dataset class will be appropriate (found in ```bionemo.esm2.model.finetune.finetune_regressor.InMemorySingleValueDataset```) as it facilitates predicting on a single value. An excerpt from the class is shown below. This example dataset expected a sequence of `Tuple` that hold `(sequence, target)` values. However, one can simply extend ```InMemorySingleValueDataset``` class in a similar way to customize your class for your data.
-
-```python
-class InMemorySingleValueDataset(Dataset):
- def __init__(
- self,
- data: Sequence[Tuple[str, float]],
- tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
- seed: int = np.random.SeedSequence().entropy,
- ):
-```
-
-For any arbitrary data file formats, user can process the data into a list of tuples containing (sequence, label) and use this dataset class. Or override the dataset class to load their custom data files.
-
-To coordinate the creation of training, validation and testing datasets from your data, we need to use a `datamodule` class. To do this we can directly use or extend the ```ESM2FineTuneDataModule``` class (located at ```bionemo.esm2.model.finetune.datamodule.ESM2FineTuneDataModule```) which defines helpful abstract methods that use your dataset class.
-
-```python
-dataset = InMemorySingleValueDataset(data)
-data_module = ESM2FineTuneDataModule(
- train_dataset=train_dataset,
- valid_dataset=valid_dataset
- micro_batch_size=4, # size of a batch to be processed in a device
- global_batch_size=8, # size of batch across all devices. Should be multiple of micro_batch_size
-)
-```
-
-# Fine-Tuning the Regressor Task Head for ESM2
-
-Now we can put these five requirements together to fine-tune a regressor task head starting from a pre-trained 650M ESM-2 model (`pretrain_ckpt_path`). We can take advantage of a simple training loop in ```bionemo.esm2.model.fnetune.train``` and use the ```train_model()`` function to start the fine-tuning process in the following.
-
-```python
-# create a List[Tuple] with (sequence, target) values
-artificial_sequence_data = [
- "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
- "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
- "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
- "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
- "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
- "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
-]
-
-data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]
-
-# we are training and validating on the same dataset for simplicity
-dataset = InMemorySingleValueDataset(data)
-data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)
-
-experiment_name = "finetune_regressor"
-n_steps_train = 50
-seed = 42
-
-# To download a 650M pre-trained ESM2 model
-pretrain_ckpt_path = load("esm2/650m:2.0")
-
-config = ESM2FineTuneSeqConfig(
- initial_ckpt_path=str(pretrain_ckpt_path)
-)
-
-checkpoint, metrics, trainer = train_model(
- experiment_name=experiment_name,
- experiment_dir=Path(experiment_results_dir), # new checkpoint will land in a subdir of this
- config=config, # same config as before since we are just continuing training
- data_module=data_module,
- n_steps_train=n_steps_train,
-)
-```
-
-This example is fully implemented in ```bionemo.esm2.model.finetune.train``` and can be executed by:
-```bash
-python -m bionemo.esm2.model.finetune.train
-```
-
-## Notes
-1. The above example is fine-tuning a 650M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load` as shown above.
- ```bash
- download_bionemo_data esm2/650m:2.0
- ```
- and pass the output path (e.g. `.../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar`) as an argument into `initial_ckpt_path` while setting the config object:
- ```python
- config = ESM2FineTuneSeqConfig(
- initial_ckpt_path=str(pretrain_ckpt_path)
- )
- ```
-2. Due to Megatron limitations, the log produced by the training run iterates on steps/iterations and not epochs. Therefore, `Training epoch` counter stays at value zero while `iteration` and `global_ste`p increase during the course of training (example in the following).
- ```bash
- Training epoch 0, iteration | ... | global_step: | reduced_train_loss: ... | val_loss: ...
- ```
- to achieve the same epoch-based effect while training, please choose the number of training steps (`n_steps_train`) so that:
- ```bash
- n_steps_train * global_batch_size = len(dataset) * desired_num_epochs
- ```
-3. We are using a small dataset of artificial sequences as our fine-tuning data in this example. You may experience over-fitting and observe no change in the validation metrics.
-
-# Fine-Tuned ESM-2 Model Inference
-Now we can use ```bionemo.esm2.model.finetune.train.infer``` to run inference on an example prediction dataset.
-Record the checkpoint path reported at the end of the finetuning run, after executing `python -m bionemo.esm2.model.finetune.train` (e.g. `/tmp/tmp1b5wlnba/finetune_regressor/checkpoints/finetune_regressor--reduced_train_loss=0.0016-epoch=0-last`) and use that as an argument to inference script (`--checkpoint-path`).
-
-We download a CSV example dataset of articical sequences for this inference example. Please refer to [ESM-2 Inference](./inference) tutorial for detailed explanation of the arguments and how to create your own CSV file.
-
-```bash
-mkdir -p $WORKDIR/esm2_finetune_tutorial
-
-# download sample data CSV for inference
-DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0)
-RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/
-
-infer_esm2 --checkpoint-path \
- --data-path $DATA_PATH \
- --results-path $RESULTS_PATH \
- --config-class ESM2FineTuneSeqConfig
-```
-
-This will create a result `.pt` file under `$WORKDIR/esm2_finetune_tutorial/predictions__rank_0.pt` which can be loaded via PyTorch library in python environment:
-
-```python
-import torch
-
-# Set the path to results file e.g. /workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt
-# results_path = /workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt
-results = torch.load(results_path)
-
-# results is a python dict which includes the following result tensors for this example:
-# results['regression_output'] is a tensor with shape: torch.Size([10, 1])
-```
-
-## Notes
-- ESM2 Inference module takes the `--checkpoint-path` and `--config-class` arguments to create a config object by pointing the path in `initial_ckpt_path`. Since we need to load all the parameters from this checkpoint (and don't skip the head) we reset the `initial_ckpt_skip_keys_with_these_prefixes` in this config.
-
- ```python
- config = ESM2FineTuneSeqConfig(
- initial_ckpt_path = ,
- initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
- )
- ```
diff --git a/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb
index 5dfd17964f..f487e73011 100644
--- a/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb
+++ b/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb
@@ -141,11 +141,40 @@
"execution_count": 4,
"metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading data from 'nvidia/clara/esm2nv650m:2.0' to file '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz'.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"download_end\": \"2025-01-14 22:01:24\",\n",
+ " \"download_start\": \"2025-01-14 22:01:05\",\n",
+ " \"download_time\": \"18s\",\n",
+ " \"files_downloaded\": 1,\n",
+ " \"local_path\": \"/home/ubuntu/.cache/bionemo/tmpfj1e52vw/esm2nv650m_v2.0\",\n",
+ " \"size_downloaded\": \"1.12 GB\",\n",
+ " \"status\": \"COMPLETED\"\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Untarring contents of '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz' to '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar'\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "/home/bionemo/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar\n"
+ "/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar\n"
]
}
],
@@ -168,7 +197,7 @@
"metadata": {},
"source": [
"\n",
- "We use the `InMemoryCSVDataset` class to load the protein sequence data from a `.csv` file. This data file should at least have a `sequences` column and can optionally have a `labels` column used for fine-tuning applications. Here is an example of how to create your own inference input data using a list of sequences in python:"
+ "We use the `InMemoryProteinDataset` class to load the protein sequence data from a `.csv` file. This data file should at least have a `sequences` column and can optionally have a `labels` column used for fine-tuning applications. Here is an example of how to create your own inference input data using a list of sequences in python:"
]
},
{
@@ -238,12 +267,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "2024-12-16 20:19:23 - faiss.loader - INFO - Loading faiss with AVX512 support.\n",
- "2024-12-16 20:19:23 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n",
- "[NeMo W 2024-12-16 20:19:24 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
+ "2025-01-14 22:01:45 - faiss.loader - INFO - Loading faiss with AVX512 support.\n",
+ "2025-01-14 22:01:45 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n",
+ "[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
" warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n",
" \n",
- "[NeMo W 2024-12-16 20:19:24 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n",
+ "[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n",
" cm = get_cmap(\"Set1\")\n",
" \n",
"usage: infer_esm2 [-h] --checkpoint-path CHECKPOINT_PATH --data-path DATA_PATH\n",
@@ -533,7 +562,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.12.3"
}
},
"nbformat": 4,
diff --git a/docs/docs/user-guide/getting-started/development.md b/docs/docs/user-guide/getting-started/development.md
index ce97a78cbe..ae8a0997f7 100644
--- a/docs/docs/user-guide/getting-started/development.md
+++ b/docs/docs/user-guide/getting-started/development.md
@@ -136,7 +136,7 @@ of the model. The fine-tuning steps will be application-specific, but a general
6. **Run inference**: Once the model is fine-tuned, use it to make predictions on new, unseen data.
For more information on fine-tuning a model, refer to the [ESM-2 Fine-tuning
-Tutorial](../examples/bionemo-esm2/finetune.md).
+Tutorial](../examples/bionemo-esm2/finetune.ipynb).
## Advanced Developer Documentation
diff --git a/sub-packages/bionemo-esm2/pyproject.toml b/sub-packages/bionemo-esm2/pyproject.toml
index aa8f7715ed..4acce854fa 100644
--- a/sub-packages/bionemo-esm2/pyproject.toml
+++ b/sub-packages/bionemo-esm2/pyproject.toml
@@ -22,6 +22,7 @@ bionemo-esm2-train= "bionemo.esm2.run.main:main"
bionemo-esm2-recipe= "bionemo.esm2.run.recipes:main"
infer_esm2 = "bionemo.esm2.scripts.infer_esm2:infer_esm2_entrypoint"
train_esm2 = "bionemo.esm2.scripts.train_esm2:train_esm2_entrypoint"
+finetune_esm2 = "bionemo.esm2.scripts.finetune_esm2:finetune_esm2_entrypoint"
# Make sure that the tokenizer files are included along with the python files during installation.
[tool.setuptools.package-data]
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
index 09526572ef..7104f64373 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
@@ -15,119 +15,27 @@
import functools
-import os
-from typing import Literal, Sequence, Tuple, Union
+from typing import Literal, Union
-import numpy as np
-import pandas as pd
-import torch
-import torch.utils.data
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from nemo.lightning.data import WrappedDataLoader
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging
-from torch import Tensor
-from torch.utils.data import Dataset
from bionemo.core.data.multi_epoch_dataset import IdentityMultiEpochDatasetWrapper, MultiEpochDatasetResampler
from bionemo.esm2.data import tokenizer
-from bionemo.esm2.model.finetune.finetune_regressor import InMemorySingleValueDataset
-from bionemo.esm2.model.finetune.finetune_token_classifier import InMemoryPerTokenValueDataset
+from bionemo.esm2.model.finetune.dataset import (
+ InMemoryPerTokenValueDataset,
+ InMemoryProteinDataset,
+ InMemorySingleValueDataset,
+)
from bionemo.llm.data import collate
from bionemo.llm.data.datamodule import MegatronDataModule
-from bionemo.llm.data.types import BertSample
from bionemo.llm.utils.datamodule_utils import infer_num_samples
Mode = Literal["train", "validation", "test", "predict"]
-
-
-class InMemoryCSVDataset(Dataset):
- """An in-memory dataset that tokenize strings into BertSample instances."""
-
- def __init__(
- self,
- data_path: str | os.PathLike,
- tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
- seed: int = np.random.SeedSequence().entropy, # type: ignore
- ):
- """Initializes a dataset for single-value regression fine-tuning.
-
- This is an in-memory dataset that does not apply masking to the sequence. But keeps track of in the
- dataset sequences provided.
-
- Args:
- data_path (str | os.PathLike): A path to the CSV file containing sequences.
- labels (Optional[Sequence[float | str]]): An optional sequence of labels with 1:1 mapping to sequences.
- tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
- seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
- that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
- generated.
- """
- self.sequences, self.labels = self.load_data(data_path)
-
- self.seed = seed
- self._len = len(self.sequences)
- self.tokenizer = tokenizer
-
- def __len__(self) -> int:
- """The size of the dataset."""
- return self._len
-
- def __getitem__(self, index: int) -> BertSample:
- """Obtains the BertSample at the given index."""
- sequence = self.sequences[index]
- tokenized_sequence = self._tokenize(sequence)
-
- label = tokenized_sequence if len(self.labels) == 0 else torch.Tensor([self.labels[index]])
- # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
- loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
-
- return {
- "text": tokenized_sequence,
- "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
- "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
- "labels": label,
- "loss_mask": loss_mask,
- "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
- }
-
- def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]:
- """Loads data from a CSV file, returning sequences and optionally labels.
-
- This method should be implemented by subclasses to process labels for their specific dataset.
-
- Args:
- csv_path (str | os.PathLike): The path to the CSV file containing the data.
- The file is expected to have at least one column named 'sequence'. A 'label' column is optional.
-
- Returns:
- Tuple[Sequence, Sequence]: A tuple where the first element is a list of sequences and the second element is
- a list of labels. If the 'label' column is not present, an empty list is returned for labels.
- """
- df = pd.read_csv(csv_path)
- sequences = df["sequences"].tolist()
-
- if "labels" in df.columns:
- labels = df["labels"].tolist()
- else:
- labels = []
- return sequences, labels
-
- def _tokenize(self, sequence: str) -> Tensor:
- """Tokenize a protein sequence.
-
- Args:
- sequence: The protein sequence.
-
- Returns:
- The tokenized sequence.
- """
- tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
- return tensor.flatten() # type: ignore
-
-
-DATASET_TYPES = Union[InMemoryPerTokenValueDataset, InMemorySingleValueDataset, InMemoryCSVDataset, None]
+DATASET_TYPES = Union[InMemoryPerTokenValueDataset, InMemorySingleValueDataset, InMemoryProteinDataset, None]
class ESM2FineTuneDataModule(MegatronDataModule):
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py
new file mode 100644
index 0000000000..542854548d
--- /dev/null
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/dataset.py
@@ -0,0 +1,221 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from typing import Sequence
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.utils.data
+from torch import Tensor
+from torch.utils.data import Dataset
+
+from bionemo.esm2.data import tokenizer
+from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX
+from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer
+from bionemo.llm.data.types import BertSample
+
+
+__all__: Sequence[str] = (
+ "InMemoryProteinDataset",
+ "InMemorySingleValueDataset",
+ "InMemoryPerTokenValueDataset",
+)
+
+
+class InMemoryProteinDataset(Dataset):
+ """An in-memory dataset that tokenize strings into BertSample instances."""
+
+ def __init__(
+ self,
+ sequences: pd.Series,
+ labels: pd.Series | None = None,
+ tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+ seed: int = np.random.SeedSequence().entropy, # type: ignore
+ ):
+ """Initializes a dataset of protein sequences.
+
+ This is an in-memory dataset that does not apply masking to the sequence, but keeps track of
+ the dataset sequences provided.
+
+ Args:
+ sequences (pd.Series): A pandas Series containing protein sequences.
+ labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
+ tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+ seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+ that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+ generated.
+ """
+ self.sequences = sequences
+ self.labels = labels
+
+ self.seed = seed
+ self._len = len(self.sequences)
+ self.tokenizer = tokenizer
+
+ @classmethod
+ def from_csv(
+ cls, csv_path: str | os.PathLike, tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer()
+ ):
+ """Class method to create a ProteinDataset instance from a CSV file."""
+ df = pd.read_csv(csv_path)
+
+ # Validate presence of required columns
+ if "sequences" not in df.columns:
+ raise KeyError("The CSV must contain a 'sequences' column.")
+
+ sequences = df["sequences"]
+ labels = df["labels"] if "labels" in df.columns else None
+ return cls(sequences, labels, tokenizer)
+
+ def __len__(self) -> int:
+ """The size of the dataset."""
+ return self._len
+
+ def __getitem__(self, index: int) -> BertSample:
+ """Obtains the BertSample at the given index."""
+ sequence = self.sequences[index]
+ tokenized_sequence = self._tokenize(sequence)
+
+ label = tokenized_sequence if self.labels is None else self.transform_label(self.labels.iloc[index])
+ # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
+ loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
+
+ return {
+ "text": tokenized_sequence,
+ "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+ "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+ "labels": label,
+ "loss_mask": loss_mask,
+ "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+ }
+
+ def _tokenize(self, sequence: str) -> Tensor:
+ """Tokenize a protein sequence.
+
+ Args:
+ sequence: The protein sequence.
+
+ Returns:
+ The tokenized sequence.
+ """
+ tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
+ return tensor.flatten() # type: ignore
+
+ def transform_label(self, label):
+ """Transform the label.
+
+ This method should be implemented by a subclass if the label needs additional transformation.
+
+ Args:
+ label: label to be transformed
+
+ Returns:
+ transformed_label
+ """
+ return label
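+
+# Usage sketch (illustrative only; "proteins.csv" is a hypothetical file with the required
+# 'sequences' column and an optional 'labels' column):
+#   dataset = InMemoryProteinDataset.from_csv("proteins.csv")
+#   sample = dataset[0]  # a BertSample dict with 'text', 'labels', 'loss_mask', ...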
+
+
+class InMemorySingleValueDataset(InMemoryProteinDataset):
+ """An in-memory dataset that tokenizes strings into BertSample instances."""
+
+ def __init__(
+ self,
+ sequences: pd.Series,
+ labels: pd.Series | None = None,
+ tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+ seed: int = np.random.SeedSequence().entropy, # type: ignore
+ ):
+ """Initializes a dataset for single-value regression fine-tuning.
+
+ This is an in-memory dataset that does not apply masking to the sequence, but keeps track of
+ the dataset sequences provided.
+
+ Args:
+ sequences (pd.Series): A pandas Series containing protein sequences.
+ labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
+ tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+ seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+ that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+ generated.
+ """
+ super().__init__(sequences, labels, tokenizer, seed)
+
+ def transform_label(self, label: float) -> Tensor:
+ """Transform the regression label.
+
+ Args:
+ label: regression value
+
+ Returns:
+ tokenized label
+ """
+ return torch.tensor([label], dtype=torch.float)
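+
+# For example (illustrative): transform_label(0.57) returns tensor([0.5700]), a one-element
+# float tensor used as the regression target for the whole sequence.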
+
+
+class InMemoryPerTokenValueDataset(InMemoryProteinDataset):
+ """An in-memory dataset of labeled strings, which are tokenized on demand."""
+
+ def __init__(
+ self,
+ sequences: pd.Series,
+ labels: pd.Series | None = None,
+ tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+ seed: int = np.random.SeedSequence().entropy, # type: ignore
+ ):
+ """Initializes a dataset for per-token classification fine-tuning.
+
+ This is an in-memory dataset that does not apply masking to the sequence, but keeps track of
+ the dataset sequences provided.
+
+ Args:
+ sequences (pd.Series): A pandas Series containing protein sequences.
+ labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
+ tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+ seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+ that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+ generated.
+ """
+ super().__init__(sequences, labels, tokenizer, seed)
+ label_tokenizer = Label2IDTokenizer()
+ self.label_tokenizer = label_tokenizer.build_vocab("CHE")
+ self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX
+
+ def transform_label(self, label: str) -> Tensor:
+ """Transform the sequence label by tokenizing them.
+
+ This method tokenizes the secondary structure token sequences.
+
+ Args:
+ label: secondary structure token sequences to be transformed
+
+ Returns:
+ tokenized label
+ """
+ label_ids = torch.tensor(self.label_tokenizer.text_to_ids(label))
+
+ # # for multi-label classification with BCEWithLogitsLoss
+ # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
+ # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype)
+
+ # for multi-class (mutually exclusive) classification with CrossEntropyLoss
+ tokenized_labels = label_ids
+ cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype)
+
+ # add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence
+ labels = torch.cat((cls_eos, tokenized_labels, cls_eos))
+ return labels
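+
+# Illustrative note (not part of the module's API): per-token labels are expected as strings over
+# the "CHE" secondary-structure vocabulary, one character per residue (e.g. "CCHHHEEC" for an
+# 8-residue sequence). transform_label() pads both ends with the -100 ignore index so the label
+# lines up with the tokenized sequence, which includes the CLS/EOS special tokens.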
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py
index f63a194190..27d3375864 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_regressor.py
@@ -17,17 +17,13 @@
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Tuple, Type
-import numpy as np
import torch
from megatron.core import parallel_state
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import Tensor
-from torch.utils.data import Dataset
from bionemo.esm2.api import ESM2GenericConfig, ESM2Model
-from bionemo.esm2.data import tokenizer
-from bionemo.llm.data.types import BertSample
from bionemo.llm.model.biobert.model import BioBertOutput
from bionemo.llm.model.loss import BERTMLMLossWithReduction, PerTokenLossDict, SameSizeLossDict
from bionemo.llm.utils import iomixin_utils as iom
@@ -41,7 +37,6 @@
"MegatronMLPHead",
"ESM2FineTuneSeqModel",
"ESM2FineTuneSeqConfig",
- "InMemorySingleValueDataset",
)
@@ -178,61 +173,3 @@ class ESM2FineTuneSeqConfig(
def get_loss_reduction_class(self) -> Type[RegressorLossReduction]:
"""Returns RegressorLossReduction class."""
return RegressorLossReduction
-
-
-class InMemorySingleValueDataset(Dataset):
- """An in-memory dataset that tokenizes strings into BertSample instances."""
-
- def __init__(
- self,
- data: Sequence[Tuple[str, float]],
- tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
- seed: int = np.random.SeedSequence().entropy, # type: ignore
- ):
- """Initializes a dataset for single-value regression fine-tuning.
-
- This is an in-memory dataset that does not apply masking to the sequence.
-
- Args:
- data (Sequence[Tuple[str, float]]): A sequence of tuples containing the sequence and target data.
- tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
- seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
- that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
- generated.
- """
- self.data = data
- self.seed = seed
- self._len = len(self.data)
- self.tokenizer = tokenizer
-
- def __len__(self) -> int:
- """The size of the dataset."""
- return self._len
-
- def __getitem__(self, index: int) -> BertSample:
- """Obtains the BertSample at the given index."""
- sequence, target = self.data[index]
- tokenized_sequence = self._tokenize(sequence)
- # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
- loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
-
- return {
- "text": tokenized_sequence,
- "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
- "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
- "labels": torch.tensor([target], dtype=torch.float),
- "loss_mask": loss_mask,
- "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
- }
-
- def _tokenize(self, sequence: str) -> Tensor:
- """Tokenize a protein sequence.
-
- Args:
- sequence: The protein sequence.
-
- Returns:
- The tokenized sequence.
- """
- tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
- return tensor.flatten() # type: ignore
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py
index b0991669d8..fe67cf2ac8 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/finetune_token_classifier.py
@@ -15,21 +15,15 @@
from dataclasses import dataclass, field
-from typing import List, Sequence, Tuple, Type, TypedDict
+from typing import Dict, List, Sequence, Tuple, Type
-import numpy as np
import torch
from megatron.core import parallel_state
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import Tensor
-from torch.utils.data import Dataset
from bionemo.esm2.api import ESM2GenericConfig, ESM2Model
-from bionemo.esm2.data import tokenizer
-from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX
-from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer
-from bionemo.llm.data.types import BertSample
from bionemo.llm.model.biobert.model import BioBertOutput
from bionemo.llm.model.loss import BERTMLMLossWithReduction, PerTokenLossDict, SameSizeLossDict
from bionemo.llm.utils import iomixin_utils as iom
@@ -44,25 +38,9 @@
"MegatronConvNetHead",
"ESM2FineTuneTokenModel",
"ESM2FineTuneTokenConfig",
- "InMemoryPerTokenValueDataset",
- "ClassifierInput",
- "Esm2FineTuneTokenOutput",
)
-class ClassifierInput(TypedDict):
- """Used as input in the ClassifierLossReduction's forward method."""
-
- labels: Tensor
- loss_mask: Tensor
-
-
-class Esm2FineTuneTokenOutput(BioBertOutput):
- """Inference output from ESM2FineTuneTokenModel."""
-
- classification_output: Tensor
-
-
class ClassifierLossReduction(BERTMLMLossWithReduction):
"""A class for calculating the cross entropy loss of classification output.
@@ -70,7 +48,7 @@ class ClassifierLossReduction(BERTMLMLossWithReduction):
"""
def forward(
- self, batch: ClassifierInput, forward_out: Esm2FineTuneTokenOutput
+ self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
"""Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
@@ -159,9 +137,9 @@ def __init__(self, config, *args, include_hiddens: bool = False, post_process: b
# if we are doing post process (eg pipeline last stage) then we need to add the output layers
self.classification_head = MegatronConvNetHead(config)
- def forward(self, *args, **kwargs) -> Tensor | BioBertOutput | Esm2FineTuneTokenOutput:
+ def forward(self, *args, **kwargs) -> Tensor | BioBertOutput:
"""Inference."""
- output: Tensor | BioBertOutput | Esm2FineTuneTokenOutput = super().forward(*args, **kwargs)
+ output = super().forward(*args, **kwargs)
# Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
if not self.post_process:
return output # we are not at the last pipeline stage so just return what the parent has
@@ -203,80 +181,3 @@ class ESM2FineTuneTokenConfig(
def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
"""The loss function type."""
return ClassifierLossReduction
-
-
-class InMemoryPerTokenValueDataset(Dataset):
- """An in-memory dataset of labeled strings, which are tokenized on demand."""
-
- def __init__(
- self,
- data: Sequence[Tuple[str, str]],
- tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
- seed: int = np.random.SeedSequence().entropy, # type: ignore
- ):
- """Initializes a dataset for per-token classification fine-tuning.
-
- This is an in-memory dataset that does not apply masking to the sequence.
-
- Args:
- data: A sequence of tuples containing the sequence and target data.
- tokenizer: The tokenizer to use. Defaults to tokenizer.get_tokenizer().
- seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to
- ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random
- seed is generated.
- """
- self.data = data
- self.seed = seed
- self._len = len(self.data)
- self.tokenizer = tokenizer
- label_tokenizer = Label2IDTokenizer()
- self.label_tokenizer = label_tokenizer.build_vocab("CHE")
- self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX
-
- def __len__(self) -> int:
- """Length of dataset."""
- return self._len
-
- def __getitem__(self, index: int) -> BertSample:
- """Gets a BertSample associated to the supplied index."""
- sequence, target = self.data[index]
- tokenized_sequence = self._tokenize(sequence)
- # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
- loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(self.tokenizer.all_special_ids))
- labels = self._tokenize_labels(target)
-
- return {
- "text": tokenized_sequence,
- "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
- "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
- "labels": labels,
- "loss_mask": loss_mask,
- "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
- }
-
- def _tokenize_labels(self, labels_sequence: str) -> Tensor:
- label_ids = torch.tensor(self.label_tokenizer.text_to_ids(labels_sequence))
-
- # # for multi-label classification with BCEWithLogitsLoss
- # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
- # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype)
-
- # for multi-class (mutually exclusive) classification with CrossEntropyLoss
- tokenized_labels = label_ids
- cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype)
-
- # add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence
- labels = torch.cat((cls_eos, tokenized_labels, cls_eos))
- return labels
-
- def _tokenize(self, sequence: str) -> Tensor:
- """Tokenize a protein sequence.
-
- Args:
- sequence: The protein sequence.
-
- Returns:
- The tokenized sequence.
- """
- tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
- return tensor.flatten() # type: ignore
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
deleted file mode 100644
index 638729e3f4..0000000000
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import tempfile
-from pathlib import Path
-from typing import Sequence, Tuple
-
-import lightning.pytorch as pl
-from lightning.pytorch.callbacks import Callback, RichModelSummary
-from lightning.pytorch.loggers import TensorBoardLogger
-from megatron.core.optimizer.optimizer_config import OptimizerConfig
-from nemo import lightning as nl
-from nemo.collections import llm as nllm
-from nemo.lightning import resume
-from nemo.lightning.nemo_logger import NeMoLogger
-from nemo.lightning.pytorch import callbacks as nl_callbacks
-from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
-from nemo.lightning.pytorch.callbacks.peft import PEFT
-from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
-
-from bionemo.core.data.load import load
-from bionemo.esm2.api import ESM2GenericConfig
-from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer
-from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
-from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig, InMemorySingleValueDataset
-from bionemo.llm.model.biobert.lightning import biobert_lightning_module
-
-
-__all__: Sequence[str] = ("train_model",)
-
-
-def train_model(
- experiment_name: str,
- experiment_dir: Path,
- config: ESM2GenericConfig,
- data_module: pl.LightningDataModule,
- n_steps_train: int,
- metric_tracker: Callback | None = None,
- tokenizer: BioNeMoESMTokenizer = get_tokenizer(),
- peft: PEFT | None = None,
- _use_rich_model_summary: bool = True,
-) -> Tuple[Path, Callback | None, nl.Trainer]:
- """Trains a BioNeMo ESM2 model using PyTorch Lightning.
-
- Parameters:
- experiment_name: The name of the experiment.
- experiment_dir: The directory where the experiment will be saved.
- config: The configuration for the ESM2 model.
- data_module: The data module for training and validation.
- n_steps_train: The number of training steps.
- metric_tracker: Optional callback to track metrics
- tokenizer: The tokenizer to use. Defaults to `get_tokenizer()`.
- peft: The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to None.
- _use_rich_model_summary: Whether to use the RichModelSummary callback, omitted in our test suite until
- https://nvbugspro.nvidia.com/bug/4959776 is resolved. Defaults to True.
-
- Returns:
- A tuple containing the path to the saved checkpoint, a MetricTracker
- object, and the PyTorch Lightning Trainer object.
- """
- checkpoint_callback = nl_callbacks.ModelCheckpoint(
- save_last=True,
- save_on_train_epoch_end=True,
- monitor="reduced_train_loss", # TODO find out how to get val_loss logged and use "val_loss",
- every_n_train_steps=n_steps_train // 2,
- always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
- )
-
- # Setup the logger and train the model
- nemo_logger = NeMoLogger(
- log_dir=str(experiment_dir),
- name=experiment_name,
- tensorboard=TensorBoardLogger(save_dir=experiment_dir, name=experiment_name),
- ckpt=checkpoint_callback,
- )
- # Needed so that the trainer can find an output directory for the profiler
- # ckpt_path needs to be a string for SerDe
- optimizer = MegatronOptimizerModule(
- config=OptimizerConfig(
- lr=5e-4,
- optimizer="adam",
- use_distributed_optimizer=True,
- fp16=config.fp16,
- bf16=config.bf16,
- )
- )
- module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft)
-
- strategy = nl.MegatronStrategy(
- tensor_model_parallel_size=1,
- pipeline_model_parallel_size=1,
- ddp="megatron",
- find_unused_parameters=True,
- enable_nemo_ckpt_io=True,
- )
-
- if _use_rich_model_summary:
- # RichModelSummary is not used in the test suite until https://nvbugspro.nvidia.com/bug/4959776 is resolved due
- # to errors with serialization / deserialization.
- callbacks: list[Callback] = [RichModelSummary(max_depth=4)]
- else:
- callbacks = []
-
- if metric_tracker is not None:
- callbacks.append(metric_tracker)
- if peft is not None:
- callbacks.append(
- ModelTransform()
- ) # Callback needed for PEFT fine-tuning using NeMo2, i.e. biobert_lightning_module(model_transform=peft).
-
- trainer = nl.Trainer(
- accelerator="gpu",
- devices=1,
- strategy=strategy,
- limit_val_batches=2,
- val_check_interval=n_steps_train // 2,
- max_steps=n_steps_train,
- num_nodes=1,
- log_every_n_steps=n_steps_train // 2,
- callbacks=callbacks,
- plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
- )
- nllm.train(
- model=module,
- data=data_module,
- trainer=trainer,
- log=nemo_logger,
- resume=resume.AutoResume(
- resume_if_exists=True, # Looks for the -last checkpoint to continue training.
- resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
- ),
- )
- ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
- return ckpt_path, metric_tracker, trainer
-
-
-if __name__ == "__main__":
- # set the results directory
- experiment_results_dir = tempfile.TemporaryDirectory().name
-
- # create a List[Tuple] with (sequence, target) values
- artificial_sequence_data = [
- "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
- "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
- "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
- "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
- "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
- "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
- ]
- data = [(seq, len(seq) / 100.0) for seq in artificial_sequence_data]
-
- # we are training and validating on the same dataset for simplicity
- dataset = InMemorySingleValueDataset(data)
- data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)
-
- experiment_name = "finetune_regressor"
- n_steps_train = 50
- seed = 42
-
- # To download a 650M pre-trained ESM2 model
- pretrain_ckpt_path = load("esm2/650m:2.0")
-
- config = ESM2FineTuneSeqConfig(initial_ckpt_path=str(pretrain_ckpt_path))
-
- checkpoint, metrics, trainer = train_model(
- experiment_name=experiment_name,
- experiment_dir=Path(experiment_results_dir), # new checkpoint will land in a subdir of this
- config=config, # same config as before since we are just continuing training
- data_module=data_module,
- n_steps_train=n_steps_train,
- )
- print(f"Experiment completed with checkpoint stored at {checkpoint}")
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py
new file mode 100644
index 0000000000..1b35f169ff
--- /dev/null
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py
@@ -0,0 +1,635 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from pathlib import Path
+from typing import Dict, List, Optional, Sequence, Tuple, Type, get_args
+
+from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary
+from megatron.core.distributed import DistributedDataParallelConfig
+from megatron.core.optimizer import OptimizerConfig
+from nemo import lightning as nl
+from nemo.collections import llm
+from nemo.lightning import resume
+from nemo.lightning.pytorch import callbacks as nl_callbacks
+from nemo.lightning.pytorch.optim import MegatronOptimizerModule
+
+from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
+from bionemo.esm2.data.tokenizer import get_tokenizer
+from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
+from bionemo.esm2.model.finetune.dataset import (
+ InMemoryPerTokenValueDataset,
+ InMemoryProteinDataset,
+ InMemorySingleValueDataset,
+)
+from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
+from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
+from bionemo.llm.model.biobert.lightning import biobert_lightning_module
+from bionemo.llm.model.biobert.model import BioBertConfig
+from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size
+from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger
+
+
+__all__: Sequence[str] = ("train_model", "finetune_esm2_entrypoint", "get_parser")
+
+
+SUPPORTED_CONFIGS = {
+ "ESM2FineTuneSeqConfig": ESM2FineTuneSeqConfig,
+ "ESM2FineTuneTokenConfig": ESM2FineTuneTokenConfig,
+}
+
+SUPPORTED_DATASETS = {
+ "InMemoryProteinDataset": InMemoryProteinDataset,
+ "InMemorySingleValueDataset": InMemorySingleValueDataset,
+ "InMemoryPerTokenValueDataset": InMemoryPerTokenValueDataset,
+}
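+
+# Informational: InMemoryProteinDataset is the label-free base class (used, for example, for
+# inference), InMemorySingleValueDataset is intended to pair with ESM2FineTuneSeqConfig for
+# sequence-level regression, and InMemoryPerTokenValueDataset with ESM2FineTuneTokenConfig for
+# per-token classification.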
+
+
+def train_model(
+ train_data_path: Path,
+ valid_data_path: Path,
+ num_nodes: int,
+ devices: int,
+ min_seq_length: Optional[int],
+ max_seq_length: int,
+ result_dir: Path,
+ num_steps: int,
+ limit_val_batches: int,
+ val_check_interval: int,
+ log_every_n_steps: Optional[int],
+ num_dataset_workers: int,
+ lr: float,
+ micro_batch_size: int,
+ accumulate_grad_batches: int,
+ experiment_name: str,
+ resume_if_exists: bool,
+ precision: PrecisionTypes,
+ wandb_entity: Optional[str] = None,
+ wandb_project: Optional[str] = None,
+ wandb_offline: bool = False,
+ wandb_tags: Optional[List[str]] = None,
+ wandb_group: Optional[str] = None,
+ wandb_id: Optional[str] = None,
+ wandb_anonymous: Optional[bool] = False,
+ wandb_log_model: bool = False,
+ pipeline_model_parallel_size: int = 1,
+ tensor_model_parallel_size: int = 1,
+ create_tensorboard_logger: bool = False,
+ restore_from_checkpoint_path: Optional[str] = None,
+ save_last_checkpoint: bool = True,
+ metric_to_monitor_for_checkpoints: str = "val_loss",
+ save_top_k: int = 2,
+ nsys_profiling: bool = False,
+ nsys_start_step: int = 0,
+ nsys_end_step: Optional[int] = None,
+ nsys_ranks: List[int] = [0],
+ dataset_class: Type[InMemoryProteinDataset] = InMemorySingleValueDataset,
+ config_class: Type[BioBertConfig] = ESM2FineTuneSeqConfig,
+ metric_tracker: Callback | None = None,
+ overlap_grad_reduce: bool = True,
+ overlap_param_gather: bool = True,
+ average_in_collective: bool = True,
+ grad_reduce_in_fp32: bool = False,
+) -> Tuple[Path, Callback | None, nl.Trainer]:
+ """Train an ESM2 model on UR data.
+
+ Args:
+ train_data_path (Path): path to train CSV
+ valid_data_path (Path): path to validation CSV
+ num_nodes (int): Number of nodes to run on
+ devices (int): number of devices
+ min_seq_length (Optional[int]): minimum sequence length
+ max_seq_length (int): maximum sequence length
+ result_dir (Path): directory to store results, logs and checkpoints
+ num_steps (int): number of steps to train the model for
+ limit_val_batches (int): limit the number of validation global batches to this many
+ val_check_interval (int): number of steps to periodically check the validation loss
+ log_every_n_steps (Optional[int]): log every n steps
+ num_dataset_workers (int): number of dataset workers
+ lr (float): learning rate
+ micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
+ accumulate_grad_batches (int): number of batches to accumulate gradients for
+ experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
+ result_dir that stores the logs and checkpoints.
+ resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
+ precision (PrecisionTypes): Precision type for training (e.g., float16, float32)
+ wandb_entity (Optional[str]): The team posting this run (default: your username or your default team)
+ wandb_project (Optional[str]): The name of the project to which this run will belong
+ wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
+ wandb_tags (Optional[List[str]]): Tags associated with this run
+ wandb_group (Optional[str]): A unique string shared by all runs in a given group
+ wandb_id (Optional[str]): Sets the version, mainly used to resume a previous run
+ wandb_anonymous (Optional[bool]): Enables or explicitly disables anonymous logging
+ wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers
+ pipeline_model_parallel_size (int): pipeline model parallel size
+ tensor_model_parallel_size (int): tensor model parallel size
+ create_tensorboard_logger (bool): create the tensorboard logger
+ restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the
+ checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
+ save_last_checkpoint (bool): whether to save the last checkpoint
+ metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints
+ save_top_k (int): number of top checkpoints to save
+ nsys_profiling (bool): whether to enable nsys profiling
+ nsys_start_step (int): start step for nsys profiling
+ nsys_end_step (Optional[int]): end step for nsys profiling
+ nsys_ranks (List[int]): ranks for nsys profiling
+ dataset_class (Type[InMemoryProteinDataset]): The dataset class for loading the data from a CSV file
+ config_class (Type[BioBertConfig]): The config class for configuring the model using checkpoint provided
+ metric_tracker: Optional callback to track metrics (used for testing)
+ overlap_grad_reduce (bool): overlap gradient reduction
+ overlap_param_gather (bool): overlap parameter gather
+ average_in_collective (bool): average in collective
+ grad_reduce_in_fp32 (bool): gradient reduction in fp32
+ """
+ # Create the result directory if it does not exist.
+ result_dir.mkdir(parents=True, exist_ok=True)
+
+ # Setup the strategy and trainer
+ global_batch_size = infer_global_batch_size(
+ micro_batch_size=micro_batch_size,
+ num_nodes=num_nodes,
+ devices=devices,
+ accumulate_grad_batches=accumulate_grad_batches,
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ )
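+ # Sketch of the expected relationship (assuming the usual Megatron data-parallel convention):
+ # global_batch_size = micro_batch_size * accumulate_grad_batches
+ #                     * (num_nodes * devices) / (tensor_model_parallel_size * pipeline_model_parallel_size)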
+
+ strategy = nl.MegatronStrategy(
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ ddp=DistributedDataParallelConfig(
+ check_for_nan_in_grad=True,
+ overlap_grad_reduce=overlap_grad_reduce,
+ overlap_param_gather=overlap_param_gather,
+ average_in_collective=average_in_collective,
+ grad_reduce_in_fp32=grad_reduce_in_fp32,
+ use_distributed_optimizer=True,
+ ),
+ find_unused_parameters=True,
+ gradient_as_bucket_view=True,
+ ckpt_include_optimizer=True,
+ ckpt_async_save=True,
+ ckpt_parallel_load=True,
+ )
+
+ # for wandb integration
+ # Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/lightning.pytorch.loggers.html
+ wandb_config: Optional[WandbConfig] = (
+ None
+ if wandb_project is None
+ else WandbConfig(
+ offline=wandb_offline,
+ project=wandb_project,
+ entity=wandb_entity,
+ tags=wandb_tags,
+ group=wandb_group,
+ id=wandb_id,
+ anonymous=wandb_anonymous,
+ log_model=wandb_log_model,
+ )
+ )
+
+ callbacks = [
+ RichModelSummary(max_depth=4),
+ LearningRateMonitor(),
+ nl_callbacks.PreemptionCallback(),
+ ]
+ if metric_tracker is not None:
+ callbacks.append(metric_tracker)
+ if nsys_profiling:
+ if nsys_end_step is None:
+ nsys_end_step = num_steps
+ callbacks.append(
+ nl_callbacks.NsysCallback(
+ start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
+ )
+ )
+
+ trainer = nl.Trainer(
+ devices=devices,
+ max_steps=num_steps,
+ accelerator="gpu",
+ strategy=strategy,
+ limit_val_batches=limit_val_batches, # This controls upsampling and downsampling
+ val_check_interval=val_check_interval,
+ log_every_n_steps=log_every_n_steps,
+ num_nodes=num_nodes,
+ callbacks=callbacks,
+ plugins=nl.MegatronMixedPrecision(
+ precision=precision,
+ params_dtype=get_autocast_dtype(precision),
+ pipeline_dtype=get_autocast_dtype(precision),
+ grad_reduce_in_fp32=grad_reduce_in_fp32,
+ autocast_enabled=False,
+ ),
+ )
+
+ tokenizer = get_tokenizer()
+
+ # Initialize the data module.
+ train_dataset = dataset_class.from_csv(train_data_path)
+ valid_dataset = dataset_class.from_csv(valid_data_path)
+
+ data_module = ESM2FineTuneDataModule(
+ train_dataset=train_dataset,
+ valid_dataset=valid_dataset,
+ global_batch_size=global_batch_size,
+ micro_batch_size=micro_batch_size,
+ min_seq_length=min_seq_length,
+ max_seq_length=max_seq_length,
+ num_workers=num_dataset_workers,
+ tokenizer=tokenizer,
+ )
+ # Configure the model
+ config = config_class(
+ params_dtype=get_autocast_dtype(precision),
+ pipeline_dtype=get_autocast_dtype(precision),
+ autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ initial_ckpt_path=str(restore_from_checkpoint_path),
+ # initial_ckpt_skip_keys_with_these_prefixes=[], # load everything from the checkpoint.
+ )
+
+ optimizer = MegatronOptimizerModule(
+ config=OptimizerConfig(
+ lr=lr,
+ optimizer="adam", # fused_adam not supported
+ use_distributed_optimizer=True,
+ weight_decay=0.01,
+ adam_beta1=0.9,
+ adam_beta2=0.98,
+ )
+ )
+
+ module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer)
+
+ # Configure our custom Checkpointer
+ checkpoint_callback = nl_callbacks.ModelCheckpoint(
+ save_last=save_last_checkpoint,
+ monitor=metric_to_monitor_for_checkpoints, # "val_loss",
+ save_top_k=save_top_k,
+ every_n_train_steps=val_check_interval,
+ always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+ filename="checkpoint-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
+ )
+
+ # Setup the logger and train the model
+ nemo_logger = setup_nemo_lightning_logger(
+ root_dir=result_dir,
+ name=experiment_name,
+ initialize_tensorboard_logger=create_tensorboard_logger,
+ wandb_config=wandb_config,
+ ckpt_callback=checkpoint_callback,
+ )
+
+ llm.train(
+ model=module,
+ data=data_module,
+ trainer=trainer,
+ log=nemo_logger,
+ resume=resume.AutoResume(
+ resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training.
+ resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
+ ),
+ )
+ ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
+ return ckpt_path, metric_tracker, trainer
+
+
+def finetune_esm2_entrypoint():
+ """Entrypoint for running ESM2 finetuning."""
+ # 1. get arguments
+ parser = get_parser()
+ args = parser.parse_args()
+ # 2. Call train_model with args
+ train_model(
+ train_data_path=args.train_data_path,
+ valid_data_path=args.valid_data_path,
+ num_nodes=args.num_nodes,
+ devices=args.num_gpus,
+ min_seq_length=args.min_seq_length,
+ max_seq_length=args.max_seq_length,
+ result_dir=args.result_dir,
+ wandb_entity=args.wandb_entity,
+ wandb_project=args.wandb_project,
+ wandb_tags=args.wandb_tags,
+ wandb_group=args.wandb_group,
+ wandb_id=args.wandb_id,
+ wandb_anonymous=args.wandb_anonymous,
+ wandb_log_model=args.wandb_log_model,
+ wandb_offline=args.wandb_offline,
+ num_steps=args.num_steps,
+ limit_val_batches=args.limit_val_batches,
+ val_check_interval=args.val_check_interval,
+ log_every_n_steps=args.log_every_n_steps,
+ num_dataset_workers=args.num_dataset_workers,
+ lr=args.lr,
+ micro_batch_size=args.micro_batch_size,
+ pipeline_model_parallel_size=args.pipeline_model_parallel_size,
+ tensor_model_parallel_size=args.tensor_model_parallel_size,
+ accumulate_grad_batches=args.accumulate_grad_batches,
+ precision=args.precision,
+ experiment_name=args.experiment_name,
+ resume_if_exists=args.resume_if_exists,
+ restore_from_checkpoint_path=args.restore_from_checkpoint_path,
+ save_last_checkpoint=args.save_last_checkpoint,
+ metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
+ save_top_k=args.save_top_k,
+ nsys_profiling=args.nsys_profiling,
+ nsys_start_step=args.nsys_start_step,
+ nsys_end_step=args.nsys_end_step,
+ nsys_ranks=args.nsys_ranks,
+ dataset_class=args.dataset_class,
+ config_class=args.config_class,
+ overlap_grad_reduce=not args.no_overlap_grad_reduce,
+ overlap_param_gather=not args.no_overlap_param_gather,
+ average_in_collective=not args.no_average_in_collective,
+ grad_reduce_in_fp32=args.grad_reduce_in_fp32,
+ )
+
+
+def get_parser():
+ """Return the cli parser for this tool."""
+ # TODO migrate to hydra config
+ # Parse the arguments and pull them out into local variables for ease of future refactor to a
+ # config management system.
+ parser = argparse.ArgumentParser(description="Fine-tune ESM2 on provided CSV data.")
+ parser.add_argument(
+ "--train-data-path",
+ type=Path,
+ required=True,
+ help="Path to the train data CSV file",
+ )
+ parser.add_argument(
+ "--valid-data-path",
+ type=Path,
+ required=True,
+ help="Path to the valid data CSV file",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ choices=get_args(PrecisionTypes),
+ required=False,
+ default="bf16-mixed",
+ help="Precision type to use for training.",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ required=False,
+ default=4e-4,
+ help="Learning rate for training. Default is 4e-4",
+ )
+ parser.add_argument(
+ "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
+ )
+ # FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer
+ parser.add_argument(
+ "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists."
+ )
+ parser.add_argument(
+ "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory."
+ )
+ parser.add_argument("--experiment-name", type=str, required=False, default="esm2", help="Name of the experiment.")
+
+ parser.add_argument("--wandb-entity", type=str, default=None, help="The team posting this run")
+ parser.add_argument("--wandb-project", type=str, default=None, help="Wandb project name ")
+ parser.add_argument("--wandb-tags", nargs="+", type=str, default=None, help="Tags associated with this run")
+ parser.add_argument(
+ "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group"
+ )
+ parser.add_argument(
+ "--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run"
+ )
+ parser.add_argument(
+ "--wandb-anonymous", action="store_true", help="Enable or explicitly disable anonymous logging"
+ )
+ parser.add_argument(
+ "--wandb-log-model", action="store_true", help="Save checkpoints in wandb dir to upload on W&B servers"
+ )
+ parser.add_argument("--wandb-offline", action="store_true", help="Use wandb in offline mode")
+ parser.add_argument(
+ "--num-gpus",
+ type=int,
+ required=False,
+ default=1,
+ help="Number of GPUs to use for training. Default is 1.",
+ )
+ parser.add_argument(
+ "--num-nodes",
+ type=int,
+ required=False,
+ default=1,
+ help="Number of nodes to use for training. Default is 1.",
+ )
+ parser.add_argument(
+ "--num-steps",
+ type=int,
+ required=False,
+ default=500000,
+ help="Number of steps to use for training. Default is 500000.",
+ )
+ parser.add_argument(
+ "--num-dataset-workers",
+ type=int,
+ required=False,
+ default=1,
+ help="Number of workers to use for training. Default is 1.",
+ )
+ parser.add_argument(
+ "--val-check-interval",
+ type=int,
+ required=False,
+ default=10000,
+ help="Number of steps between validation. Default is 10000.",
+ )
+ parser.add_argument(
+ "--log-every-n-steps",
+ type=int,
+ required=False,
+ help="Number of steps between logging. Default is 50.",
+ )
+ parser.add_argument(
+ "--min-seq-length",
+ type=float_or_int_or_none,
+ required=False,
+ default=1024,
+ help="Minimum sequence length. Sampled will be padded if less than this value. Set 'None' to unset minimum.",
+ )
+ parser.add_argument(
+ "--max-seq-length",
+ type=int,
+ required=False,
+ default=1024,
+ help="Maximum sequence length. Samples will be truncated if exceeds this value.",
+ )
+ parser.add_argument(
+ "--limit-val-batches",
+ type=float_or_int_or_none,
+ required=False,
+ default=2,
+ help="Number of global batches used for validation if int. Fraction of validation dataset if float. Default is 2.",
+ )
+ parser.add_argument(
+ "--micro-batch-size",
+ type=int,
+ required=False,
+ default=64,
+ help="Micro-batch size. Global batch size is inferred from this.",
+ )
+ parser.add_argument(
+ "--pipeline-model-parallel-size",
+ type=int,
+ required=False,
+ default=1,
+ help="Pipeline model parallel size. Default is 1.",
+ )
+ parser.add_argument(
+ "--tensor-model-parallel-size",
+ type=int,
+ required=False,
+ default=1,
+ help="Tensor model parallel size. Default is 1.",
+ )
+ parser.add_argument(
+ "--accumulate-grad-batches",
+ type=int,
+ required=False,
+ default=1,
+ help="Gradient accumulation steps. Global batch size is inferred from this.",
+ )
+ parser.add_argument(
+ "--save-last-checkpoint",
+ action="store_true",
+ default=True,
+ help="Save the last checkpoint.",
+ )
+ parser.add_argument(
+ "--metric-to-monitor-for-checkpoints",
+ type=str,
+ required=False,
+ default="val_loss",
+ help="The metric to monitor for checkpointing.",
+ )
+ parser.add_argument(
+ "--save-top-k",
+ type=int,
+ required=False,
+ default=2,
+ help="Save the top k checkpoints.",
+ )
+ parser.add_argument(
+ "--restore-from-checkpoint-path",
+ type=Path,
+ required=False,
+ default=None,
+ help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.",
+ )
+ parser.add_argument(
+ "--nsys-profiling",
+ action="store_true",
+ default=False,
+ help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling output you must run the whole program with `nsys`. For example: "
+ " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop [regular python command here]`",
+ )
+ # start, end, rank
+ parser.add_argument(
+ "--nsys-start-step",
+ type=int,
+ required=False,
+ default=0,
+ help="Start nsys profiling after this step.",
+ )
+ parser.add_argument(
+ "--nsys-end-step",
+ type=int,
+ required=False,
+ help="End nsys profiling after this step.",
+ )
+ # rank as list of integers
+ parser.add_argument(
+ "--nsys-ranks",
+ type=int,
+ nargs="+",
+ required=False,
+ default=[0],
+ help="Enable nsys profiling for these ranks.",
+ )
+ # DDP config
+ parser.add_argument(
+ "--no-overlap-grad-reduce",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--no-overlap-param-gather",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--no-average-in-collective",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--grad-reduce-in-fp32",
+ action="store_true",
+ default=False,
+ )
+
+ config_class_options: Dict[str, Type[BioBertConfig]] = SUPPORTED_CONFIGS
+
+ def config_class_type(desc: str) -> Type[BioBertConfig]:
+ try:
+ return config_class_options[desc]
+ except KeyError:
+ raise argparse.ArgumentTypeError(
+ f"Do not recognize key {desc}, valid options are: {config_class_options.keys()}"
+ )
+
+ parser.add_argument(
+ "--config-class",
+ type=config_class_type,
+ default=ESM2FineTuneSeqConfig,
+ help="Model configs link model classes with losses, and handle model initialization (including from a prior "
+ "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model "
+ "class and loss, then implement and provide an alternative config class that points to a variant of that model "
+ "and alternative loss. In the future this script should also provide similar support for picking different data "
+ f"modules for fine-tuning with different data types. Choices: {config_class_options.keys()}",
+ )
+
+ dataset_class_options: Dict[str, Type[InMemoryProteinDataset]] = SUPPORTED_DATASETS
+
+ def dataset_class_type(desc: str) -> Type[InMemoryProteinDataset]:
+ try:
+ return dataset_class_options[desc]
+ except KeyError:
+ raise argparse.ArgumentTypeError(
+ f"Do not recognize key {desc}, valid options are: {dataset_class_options.keys()}"
+ )
+
+ parser.add_argument(
+ "--dataset-class",
+ type=dataset_class_type,
+ default=InMemorySingleValueDataset,
+ help=f"Dataset class name for finetuning. Choices: {config_class_options.keys()}",
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ finetune_esm2_entrypoint()
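+
+# Example invocation (illustrative only; the CSV file names and checkpoint path are hypothetical):
+#   python -m bionemo.esm2.scripts.finetune_esm2 \
+#       --train-data-path train.csv --valid-data-path valid.csv \
+#       --restore-from-checkpoint-path /path/to/esm2_checkpoint \
+#       --config-class ESM2FineTuneSeqConfig --dataset-class InMemorySingleValueDataset \
+#       --num-steps 50 --lr 5e-4 --micro-batch-size 2 --precision bf16-mixed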
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py
index bdfaa4fe6b..9531165c13 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py
@@ -23,7 +23,8 @@
from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
from bionemo.esm2.api import ESM2Config
from bionemo.esm2.data.tokenizer import get_tokenizer
-from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule, InMemoryCSVDataset
+from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
+from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
@@ -110,7 +111,7 @@ def infer_model(
plugins=nl.MegatronMixedPrecision(precision=precision),
)
- dataset = InMemoryCSVDataset(data_path=data_path)
+ dataset = InMemoryProteinDataset.from_csv(data_path)
datamodule = ESM2FineTuneDataModule(
predict_dataset=dataset,
micro_batch_size=micro_batch_size,
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py
index 2e0c7a0f7f..fc896b54a2 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py
@@ -87,3 +87,14 @@ def dummy_data_single_value_regression_ft(dummy_data_per_token_classification_ft
"""
data = [(seq, len(seq) / 100.0) for seq, _ in dummy_data_per_token_classification_ft]
return data
+
+
+@pytest.fixture
+def dummy_protein_sequences(dummy_data_per_token_classification_ft):
+ """Fixture providing dummy data for per-token classification fine-tuning.
+
+ Returns:
+ list: A list of dummy data for per-token classification fine-tuning.
+ """
+ data = [seq for seq, _ in dummy_data_per_token_classification_ft]
+ return data
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py
new file mode 100644
index 0000000000..c1ccfb5284
--- /dev/null
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_datamodule.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pandas as pd
+import pytest
+from torch.utils.data import DataLoader
+
+from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
+from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset
+
+
+@pytest.fixture
+def dummy_protein_csv(tmp_path, dummy_protein_sequences):
+ """Create a mock protein dataset."""
+ csv_file = tmp_path / "protein_dataset.csv"
+ # Create a DataFrame
+ df = pd.DataFrame(dummy_protein_sequences, columns=["sequences"])
+
+ # Save the DataFrame to a CSV file
+ df.to_csv(csv_file, index=False)
+ return csv_file
+
+
+@pytest.fixture
+def dataset(dummy_protein_csv):
+ return InMemoryProteinDataset.from_csv(dummy_protein_csv)
+
+
+@pytest.fixture
+def data_module(dataset):
+ return ESM2FineTuneDataModule(predict_dataset=dataset)
+
+
+def test_in_memory_csv_dataset(dataset):
+ assert len(dataset) > 0
+ sample = dataset[0]
+ assert isinstance(sample, dict)
+ assert "text" in sample
+ assert "labels" in sample
+
+
+def test_esm2_fine_tune_data_module_init(data_module):
+ assert data_module.train_dataset is None
+ assert data_module.valid_dataset is None
+ assert data_module.predict_dataset is not None
+
+
+def test_esm2_fine_tune_data_module_predict_dataloader(data_module):
+ predict_dataloader = data_module.predict_dataloader()
+ assert isinstance(predict_dataloader, DataLoader)
+ batch = next(iter(predict_dataloader))
+ assert isinstance(batch, dict)
+ assert "text" in batch
+
+
+def test_esm2_fine_tune_data_module_setup(data_module):
+ with pytest.raises(RuntimeError):
+ data_module.setup("fit")
+
+
+def test_esm2_fine_tune_data_module_train_dataloader(data_module):
+ with pytest.raises(AttributeError):
+ data_module.train_dataloader()
+
+
+def test_esm2_fine_tune_data_module_val_dataloader(data_module):
+ with pytest.raises(AttributeError):
+ data_module.val_dataloader()
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py
new file mode 100644
index 0000000000..afcd53feab
--- /dev/null
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_dataset.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pandas as pd
+import pytest
+import torch
+from torch import Tensor
+
+from bionemo.esm2.model.finetune.dataset import (
+ InMemoryPerTokenValueDataset,
+ InMemoryProteinDataset,
+ InMemorySingleValueDataset,
+)
+from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX
+from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer
+
+
+def data_to_csv(data, tmp_path, with_label=True):
+ """Create a mock protein dataset."""
+ csv_file = tmp_path / "protein_dataset.csv"
+ # Create a DataFrame
+ df = pd.DataFrame(data, columns=["sequences", "labels"] if with_label else ["sequences"])
+
+ # Save the DataFrame to a CSV file
+ df.to_csv(csv_file, index=False)
+ return csv_file
+
+
+@pytest.fixture
+def dataset_no_labels(dummy_protein_sequences, tmp_path):
+ csv_path = data_to_csv(dummy_protein_sequences, tmp_path, with_label=False)
+ return InMemoryProteinDataset.from_csv(csv_path)
+
+
+@pytest.fixture
+def dataset_regression_labels(dummy_data_single_value_regression_ft, tmp_path):
+ csv_path = data_to_csv(dummy_data_single_value_regression_ft, tmp_path, with_label=True)
+ return InMemorySingleValueDataset.from_csv(csv_path)
+
+
+@pytest.fixture
+def dataset_per_token_classification_labels(dummy_data_per_token_classification_ft, tmp_path):
+ csv_path = data_to_csv(dummy_data_per_token_classification_ft, tmp_path, with_label=True)
+ return InMemoryPerTokenValueDataset.from_csv(csv_path)
+
+
+def test_in_memory_protein_dataset_length_no_labels(dataset_no_labels, dummy_protein_sequences):
+ assert len(dataset_no_labels) == len(dummy_protein_sequences)
+
+
+def test_in_memory_protein_dataset_length_with_regression_labels(
+ dataset_regression_labels, dummy_data_single_value_regression_ft
+):
+ assert len(dataset_regression_labels) == len(dummy_data_single_value_regression_ft)
+
+
+def test_in_memory_protein_dataset_length_with_class_labels(
+ dataset_per_token_classification_labels, dummy_data_per_token_classification_ft
+):
+ assert len(dataset_per_token_classification_labels) == len(dummy_data_per_token_classification_ft)
+
+
+def test_in_memory_protein_dataset_getitem_no_labels(dataset_no_labels):
+ sample = dataset_no_labels[0]
+ assert isinstance(sample, dict)
+ assert "text" in sample
+ assert "labels" in sample
+ assert isinstance(sample["text"], Tensor)
+ assert isinstance(sample["labels"], Tensor)
+
+
+def test_in_memory_protein_dataset_getitem_with_regression_labels(dataset_regression_labels):
+ assert isinstance(dataset_regression_labels, InMemoryProteinDataset)
+ sample = dataset_regression_labels[0]
+ assert isinstance(sample, dict)
+ assert "text" in sample
+ assert "labels" in sample
+ assert isinstance(sample["text"], Tensor)
+ assert isinstance(sample["labels"], Tensor)
+ assert sample["labels"].dtype == torch.float
+
+
+def test_in_memory_protein_dataset_getitem_with_class_labels(dataset_per_token_classification_labels):
+ assert isinstance(dataset_per_token_classification_labels, InMemoryProteinDataset)
+ assert isinstance(dataset_per_token_classification_labels.label_tokenizer, Label2IDTokenizer)
+ assert dataset_per_token_classification_labels.label_cls_eos_id == MLM_LOSS_IGNORE_INDEX
+
+ sample = dataset_per_token_classification_labels[0]
+ assert isinstance(sample, dict)
+ assert "text" in sample
+ assert "labels" in sample
+ assert isinstance(sample["text"], Tensor)
+ assert isinstance(sample["labels"], Tensor)
+ assert sample["labels"].dtype == torch.int64
+
+
+def test_in_memory_protein_dataset_tokenization(dataset_no_labels):
+ sample = dataset_no_labels[0]
+ tokenized_sequence = sample["text"]
+ assert isinstance(tokenized_sequence, Tensor)
+ assert tokenized_sequence.ndim == 1 # Ensure it's flattened.
+
+
+def test_transform_classification_label(
+    dataset_per_token_classification_labels, dummy_data_per_token_classification_ft
+):
+    """Ensure per-token classification labels are tokenized and wrapped with the cls/eos ignore index."""
+    pre_transform = dummy_data_per_token_classification_ft[0][1]
+    label_ids = torch.tensor(dataset_per_token_classification_labels.label_tokenizer.text_to_ids(pre_transform))
+    cls_eos = torch.tensor([dataset_per_token_classification_labels.label_cls_eos_id])
+    post_transform = torch.cat((cls_eos, label_ids, cls_eos))
+
+    assert torch.equal(dataset_per_token_classification_labels.transform_label(pre_transform), post_transform)
+
+
+def test_transform_regression_label(dataset_regression_labels):
+    """Ensure regression labels are transformed into float tensors."""
+ transformed_label = dataset_regression_labels.transform_label(1.0)
+ assert isinstance(transformed_label, Tensor)
+ assert transformed_label.dtype == torch.float
+
+
+def test_in_memory_protein_dataset_no_labels_fallback(dataset_no_labels):
+ """Ensure the dataset works even when labels are missing."""
+ sample = dataset_no_labels[0]
+ assert isinstance(sample, dict)
+ assert "labels" in sample
+ assert isinstance(sample["labels"], Tensor)
+
+
+def test_in_memory_protein_dataset_invalid_index(dataset_no_labels):
+ """Test if out-of-range index raises an error."""
+ with pytest.raises(KeyError):
+ _ = dataset_no_labels[100]
+
+
+def test_in_memory_protein_dataset_missing_sequences_column(tmp_path):
+ """Test behavior when the CSV file is empty."""
+ csv_file = tmp_path / "invalid.csv"
+ pd.DataFrame({"wrong_column": ["MKTFFS"]}).to_csv(csv_file, index=False)
+ with pytest.raises(KeyError):
+ _ = InMemoryProteinDataset.from_csv(csv_file)
+
+
+def test_in_memory_protein_dataset_special_tokens_masking(dataset_no_labels):
+ """Ensure loss mask correctly handles special tokens."""
+ sample = dataset_no_labels[0]
+ assert "loss_mask" in sample
+ assert isinstance(sample["loss_mask"], Tensor)
+ assert sample["loss_mask"].dtype == torch.bool
+
+
+def test_in_memory_protein_dataset_non_existent_file():
+ """Ensure proper error handling for missing files."""
+ with pytest.raises(FileNotFoundError):
+ InMemoryProteinDataset.from_csv("non_existent_file.csv")
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py
deleted file mode 100644
index 9c49d70c42..0000000000
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import pytest
-from nemo.lightning import io
-
-from bionemo.core.data.load import load
-from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
-from bionemo.esm2.model.finetune.finetune_regressor import (
- ESM2FineTuneSeqConfig,
- InMemorySingleValueDataset,
-)
-from bionemo.esm2.model.finetune.finetune_token_classifier import (
- ESM2FineTuneTokenConfig,
- InMemoryPerTokenValueDataset,
-)
-from bionemo.esm2.model.finetune.peft import ESM2LoRA
-from bionemo.esm2.model.finetune.train import train_model
-from bionemo.testing import megatron_parallel_state_utils
-from bionemo.testing.callbacks import MetricTracker
-
-
-# To download a 8M internally pre-trained ESM2 model
-pretrain_ckpt_path = load("esm2/nv_8m:2.0")
-
-
-@pytest.mark.needs_gpu
-@pytest.mark.parametrize("with_peft", [True, False])
-def test_esm2_finetune_token_classifier(
- tmp_path,
- tokenizer,
- dummy_data_per_token_classification_ft,
- with_peft: bool,
- n_steps_train: int = 50,
- seed: int = 42,
-):
- if with_peft:
- pytest.xfail("FIXME PEFT fine-tuning not supported with fusions active")
-
- with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
- if with_peft:
- peft = ESM2LoRA()
- else:
- peft = None
- esm2_finetune_config = ESM2FineTuneTokenConfig(initial_ckpt_path=str(pretrain_ckpt_path))
- dataset = InMemoryPerTokenValueDataset(dummy_data_per_token_classification_ft, seed=seed)
- finetune_data_module = ESM2FineTuneDataModule(dataset, dataset)
- simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
- experiment_name="finetune_new_head",
- experiment_dir=tmp_path / "finetune_new_head", # new checkpoint will land in a subdir of this
- config=esm2_finetune_config, # same config as before since we are just continuing training
- data_module=finetune_data_module,
- n_steps_train=n_steps_train,
- metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
- tokenizer=tokenizer,
- peft=peft,
- _use_rich_model_summary=False,
- )
-
- weights_ckpt = simple_ft_checkpoint / "weights"
- assert weights_ckpt.exists()
- assert weights_ckpt.is_dir()
- assert io.is_distributed_ckpt(weights_ckpt)
- assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
-
- if with_peft:
- assert trainer.model.model_transform is not None
- model = trainer.model[0].module.module.module
- assert all(not p.requires_grad for p in model.embedding.parameters())
- assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name)
- assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name)
- assert all(p.requires_grad for p in model.classification_head.parameters())
- else:
- encoder_requires_grad = [
- p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name
- ]
- assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
-
-
-@pytest.mark.needs_gpu
-@pytest.mark.parametrize("with_peft", [True, False])
-def test_esm2_finetune_regressor(
- tmp_path,
- tokenizer,
- dummy_data_single_value_regression_ft,
- with_peft: bool,
- n_steps_train: int = 50,
- seed: int = 42,
-):
- if with_peft:
- pytest.xfail("FIXME PEFT fine-tuning not supported")
-
- with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
- if with_peft:
- peft = ESM2LoRA()
- else:
- peft = None
- esm2_regression_finetune_config = ESM2FineTuneSeqConfig(initial_ckpt_path=str(pretrain_ckpt_path))
- dataset = InMemorySingleValueDataset(dummy_data_single_value_regression_ft, seed=seed)
- finetune_data_module = ESM2FineTuneDataModule(dataset, dataset)
- simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
- experiment_name="finetune_new_head_regression",
- experiment_dir=tmp_path / "finetune_new_head_regression", # new checkpoint will land in a subdir of this
- config=esm2_regression_finetune_config, # same config as before since we are just continuing training
- data_module=finetune_data_module,
- n_steps_train=n_steps_train,
- metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
- tokenizer=tokenizer,
- peft=peft,
- _use_rich_model_summary=False,
- )
-
- weights_ckpt = simple_ft_checkpoint / "weights"
- assert weights_ckpt.exists()
- assert weights_ckpt.is_dir()
- assert io.is_distributed_ckpt(weights_ckpt)
- assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
-
- if with_peft:
- assert trainer.model.model_transform is not None
- model = trainer.model[0].module.module.module
- assert all(not p.requires_grad for p in model.embedding.parameters())
- assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name)
- assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name)
- assert all(p.requires_grad for p in model.regression_head.parameters())
- else:
- encoder_requires_grad = [
- p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name
- ]
- assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_regressor.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_regressor.py
new file mode 100644
index 0000000000..f9ad25851c
--- /dev/null
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_regressor.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from bionemo.core.data.load import load
+from bionemo.esm2.data import tokenizer
+from bionemo.esm2.model.finetune.finetune_regressor import (
+ ESM2FineTuneSeqConfig,
+ ESM2FineTuneSeqModel,
+ MegatronMLPHead,
+)
+from bionemo.testing import megatron_parallel_state_utils
+
+
+# Download an 8M internally pre-trained ESM2 model
+pretrain_ckpt_path = load("esm2/nv_8m:2.0")
+
+
+@pytest.fixture
+def config():
+ return ESM2FineTuneSeqConfig(encoder_frozen=True, ft_dropout=0.50, initial_ckpt_path=str(pretrain_ckpt_path))
+
+
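+# Building the model requires an initialized Megatron model-parallel state, which is
+# why the fixture below wraps configure_model in distributed_model_parallel_state().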
+@pytest.fixture
+def finetune_seq_model(config):
+ with megatron_parallel_state_utils.distributed_model_parallel_state():
+ model = config.configure_model(tokenizer.get_tokenizer())
+ yield model
+
+
+def test_ft_config(config):
+ assert config.initial_ckpt_skip_keys_with_these_prefixes == ["regression_head"]
+ assert config.encoder_frozen
+ assert config.ft_dropout == 0.50
+
+
+def test_ft_model_initialized(finetune_seq_model):
+ assert isinstance(finetune_seq_model, ESM2FineTuneSeqModel)
+ assert isinstance(finetune_seq_model.regression_head, MegatronMLPHead)
+ assert finetune_seq_model.post_process
+ assert not finetune_seq_model.include_embeddings_finetuning
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_token_classifier.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_token_classifier.py
new file mode 100644
index 0000000000..fe4043831e
--- /dev/null
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/finetune/test_finetune_token_classifier.py
@@ -0,0 +1,57 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from bionemo.core.data.load import load
+from bionemo.esm2.data import tokenizer
+from bionemo.esm2.model.finetune.finetune_token_classifier import (
+ ESM2FineTuneTokenConfig,
+ ESM2FineTuneTokenModel,
+ MegatronConvNetHead,
+)
+from bionemo.testing import megatron_parallel_state_utils
+
+
+# Download an 8M internally pre-trained ESM2 model
+pretrain_ckpt_path = load("esm2/nv_8m:2.0")
+
+
+@pytest.fixture
+def config():
+ return ESM2FineTuneTokenConfig(encoder_frozen=True, cnn_dropout=0.1, cnn_hidden_dim=32, cnn_num_classes=5)
+
+
+@pytest.fixture
+def finetune_token_model(config):
+ with megatron_parallel_state_utils.distributed_model_parallel_state():
+ model = config.configure_model(tokenizer.get_tokenizer())
+ yield model
+
+
+def test_ft_config(config):
+ assert config.initial_ckpt_skip_keys_with_these_prefixes == ["classification_head"]
+ assert config.encoder_frozen
+ assert config.cnn_dropout == 0.1
+ assert config.cnn_hidden_dim == 32
+ assert config.cnn_num_classes == 5
+
+
+def test_ft_model_initialized(finetune_token_model):
+ assert isinstance(finetune_token_model, ESM2FineTuneTokenModel)
+ assert isinstance(finetune_token_model.classification_head, MegatronConvNetHead)
+ assert finetune_token_model.post_process
+ assert not finetune_token_model.include_hiddens_finetuning
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py
new file mode 100644
index 0000000000..bc929a1abb
--- /dev/null
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py
@@ -0,0 +1,315 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+from unittest.mock import patch
+
+import pandas as pd
+import pytest
+from nemo.lightning import io
+
+from bionemo.core.data.load import load
+from bionemo.esm2.model.finetune.dataset import InMemoryPerTokenValueDataset, InMemorySingleValueDataset
+from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
+from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
+from bionemo.esm2.scripts.finetune_esm2 import finetune_esm2_entrypoint, get_parser, train_model
+from bionemo.testing import megatron_parallel_state_utils
+from bionemo.testing.callbacks import MetricTracker
+
+
+# Download an 8M internally pre-trained ESM2 model
+pretrain_ckpt_path = load("esm2/nv_8m:2.0")
+
+
+def data_to_csv(data, tmp_path):
+ """Create a mock protein dataset."""
+ csv_file = tmp_path / "protein_dataset.csv"
+ # Create a DataFrame
+ df = pd.DataFrame(data, columns=["sequences", "labels"])
+
+ # Save the DataFrame to a CSV file
+ df.to_csv(csv_file, index=False)
+ return csv_file
+
+
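+# data_to_csv writes a two-column CSV ("sequences", "labels"); the tests below reuse
+# the same file for the train and validation splits, with per-residue label strings
+# for token classification and scalar floats for regression.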
+@pytest.mark.needs_gpu
+def test_esm2_finetune_token_classifier(
+ tmp_path,
+ dummy_data_per_token_classification_ft,
+ n_steps_train: int = 50,
+ seed: int = 42,
+):
+ with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
+ simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
+ train_data_path=data_to_csv(dummy_data_per_token_classification_ft, tmp_path),
+ valid_data_path=data_to_csv(dummy_data_per_token_classification_ft, tmp_path),
+ experiment_name="finetune_new_head_token_classification",
+ restore_from_checkpoint_path=str(pretrain_ckpt_path),
+ num_steps=n_steps_train,
+ num_nodes=1,
+ devices=1,
+ min_seq_length=None,
+ max_seq_length=1024,
+ result_dir=tmp_path / "finetune",
+ limit_val_batches=2,
+ val_check_interval=n_steps_train // 2,
+ log_every_n_steps=n_steps_train // 2,
+ num_dataset_workers=10,
+ lr=1e-5,
+ micro_batch_size=4,
+ accumulate_grad_batches=1,
+ resume_if_exists=False,
+ precision="bf16-mixed",
+ dataset_class=InMemoryPerTokenValueDataset,
+ config_class=ESM2FineTuneTokenConfig,
+ metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
+ )
+
+ weights_ckpt = simple_ft_checkpoint / "weights"
+ assert weights_ckpt.exists()
+ assert weights_ckpt.is_dir()
+ assert io.is_distributed_ckpt(weights_ckpt)
+ assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
+
+ encoder_requires_grad = [
+ p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name
+ ]
+ assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
+
+
+@pytest.mark.needs_gpu
+def test_esm2_finetune_regressor(
+ tmp_path,
+ dummy_data_single_value_regression_ft,
+ n_steps_train: int = 50,
+ seed: int = 42,
+):
+ with megatron_parallel_state_utils.distributed_model_parallel_state(seed):
+ simple_ft_checkpoint, simple_ft_metrics, trainer = train_model(
+ train_data_path=data_to_csv(dummy_data_single_value_regression_ft, tmp_path),
+ valid_data_path=data_to_csv(dummy_data_single_value_regression_ft, tmp_path),
+ experiment_name="finetune_new_head_regression",
+ restore_from_checkpoint_path=str(pretrain_ckpt_path),
+ num_steps=n_steps_train,
+ num_nodes=1,
+ devices=1,
+ min_seq_length=None,
+ max_seq_length=1024,
+ result_dir=tmp_path / "finetune",
+ limit_val_batches=2,
+ val_check_interval=n_steps_train // 2,
+ log_every_n_steps=n_steps_train // 2,
+ num_dataset_workers=10,
+ lr=1e-5,
+ micro_batch_size=4,
+ accumulate_grad_batches=1,
+ resume_if_exists=False,
+ precision="bf16-mixed",
+ dataset_class=InMemorySingleValueDataset,
+ config_class=ESM2FineTuneSeqConfig,
+ metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]),
+ )
+
+ weights_ckpt = simple_ft_checkpoint / "weights"
+ assert weights_ckpt.exists()
+ assert weights_ckpt.is_dir()
+ assert io.is_distributed_ckpt(weights_ckpt)
+ assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
+
+ encoder_requires_grad = [
+ p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name
+ ]
+ assert not all(encoder_requires_grad), "Pretrained model is not fully frozen during fine-tuning"
+
+
+@pytest.fixture
+def mock_train_model():
+ with patch("bionemo.esm2.scripts.finetune_esm2.train_model") as mock_train:
+ yield mock_train
+
+
+@pytest.fixture
+def mock_parser_args():
+ """Fixture to create mock arguments for the parser."""
+ return [
+ "--train-data-path",
+ str(Path("train.csv")),
+ "--valid-data-path",
+ str(Path("valid.csv")),
+ "--num-gpus",
+ "1",
+ "--num-nodes",
+ "1",
+ "--min-seq-length",
+ "512",
+ "--max-seq-length",
+ "1024",
+ "--result-dir",
+ str(Path("./results")),
+ "--lr",
+ "0.001",
+ ]
+
+
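+# train_model is patched (see mock_train_model above) so the CLI wiring of the
+# entrypoint can be checked without downloading checkpoints or requiring a GPU.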
+def test_finetune_esm2_entrypoint(mock_train_model, mock_parser_args):
+ """Test the finetune_esm2_entrypoint function with mocked arguments."""
+ with patch("sys.argv", ["finetune_esm2_entrypoint.py"] + mock_parser_args):
+ finetune_esm2_entrypoint()
+
+ # Check if train_model was called once
+ mock_train_model.assert_called_once()
+
+ # Check if the arguments were passed correctly
+ called_kwargs = mock_train_model.call_args.kwargs
+ assert called_kwargs["train_data_path"] == Path("train.csv")
+ assert called_kwargs["valid_data_path"] == Path("valid.csv")
+ assert called_kwargs["devices"] == 1
+ assert called_kwargs["num_nodes"] == 1
+ assert called_kwargs["min_seq_length"] == 512
+ assert called_kwargs["max_seq_length"] == 1024
+ assert called_kwargs["lr"] == 0.001
+ assert called_kwargs["result_dir"] == Path("./results")
+
+
+def test_get_parser():
+ """Test the argument parser with all possible arguments."""
+ parser = get_parser()
+ args = parser.parse_args(
+ [
+ "--train-data-path",
+ "train.csv",
+ "--valid-data-path",
+ "valid.csv",
+ "--precision",
+ "bf16-mixed",
+ "--lr",
+ "0.001",
+ "--create-tensorboard-logger",
+ "--resume-if-exists",
+ "--result-dir",
+ "./results",
+ "--experiment-name",
+ "esm2_experiment",
+ "--wandb-entity",
+ "my_team",
+ "--wandb-project",
+ "ft_project",
+ "--wandb-tags",
+ "tag1",
+ "tag2",
+ "--wandb-group",
+ "group1",
+ "--wandb-id",
+ "1234",
+ "--wandb-anonymous",
+ "--wandb-log-model",
+ "--wandb-offline",
+ "--num-gpus",
+ "2",
+ "--num-nodes",
+ "1",
+ "--num-steps",
+ "1000",
+ "--num-dataset-workers",
+ "4",
+ "--val-check-interval",
+ "500",
+ "--log-every-n-steps",
+ "100",
+ "--min-seq-length",
+ "512",
+ "--max-seq-length",
+ "1024",
+ "--limit-val-batches",
+ "2",
+ "--micro-batch-size",
+ "32",
+ "--pipeline-model-parallel-size",
+ "2",
+ "--tensor-model-parallel-size",
+ "2",
+ "--accumulate-grad-batches",
+ "2",
+ "--save-last-checkpoint",
+ "--metric-to-monitor-for-checkpoints",
+ "val_loss",
+ "--save-top-k",
+ "5",
+ "--restore-from-checkpoint-path",
+ "./checkpoint",
+ "--nsys-profiling",
+ "--nsys-start-step",
+ "10",
+ "--nsys-end-step",
+ "50",
+ "--nsys-ranks",
+ "0",
+ "1",
+ "--no-overlap-grad-reduce",
+ "--no-overlap-param-gather",
+ "--no-average-in-collective",
+ "--grad-reduce-in-fp32",
+ "--dataset-class",
+ "InMemoryPerTokenValueDataset",
+ "--config-class",
+ "ESM2FineTuneTokenConfig",
+ ]
+ )
+
+ # Assertions for all arguments
+ assert args.train_data_path == Path("train.csv")
+ assert args.valid_data_path == Path("valid.csv")
+ assert args.precision == "bf16-mixed"
+ assert args.lr == 0.001
+ assert args.create_tensorboard_logger is True
+ assert args.resume_if_exists is True
+ assert args.result_dir == Path("./results")
+ assert args.experiment_name == "esm2_experiment"
+ assert args.wandb_entity == "my_team"
+ assert args.wandb_project == "ft_project"
+ assert args.wandb_tags == ["tag1", "tag2"]
+ assert args.wandb_group == "group1"
+ assert args.wandb_id == "1234"
+ assert args.wandb_anonymous is True
+ assert args.wandb_log_model is True
+ assert args.wandb_offline is True
+ assert args.num_gpus == 2
+ assert args.num_nodes == 1
+ assert args.num_steps == 1000
+ assert args.num_dataset_workers == 4
+ assert args.val_check_interval == 500
+ assert args.log_every_n_steps == 100
+ assert args.min_seq_length == 512
+ assert args.max_seq_length == 1024
+ assert args.limit_val_batches == 2
+ assert args.micro_batch_size == 32
+ assert args.pipeline_model_parallel_size == 2
+ assert args.tensor_model_parallel_size == 2
+ assert args.accumulate_grad_batches == 2
+ assert args.save_last_checkpoint is True
+ assert args.metric_to_monitor_for_checkpoints == "val_loss"
+ assert args.save_top_k == 5
+ assert args.restore_from_checkpoint_path == Path("./checkpoint")
+ assert args.nsys_profiling is True
+ assert args.nsys_start_step == 10
+ assert args.nsys_end_step == 50
+ assert args.nsys_ranks == [0, 1]
+ assert args.no_overlap_grad_reduce is True
+ assert args.no_overlap_param_gather is True
+ assert args.no_average_in_collective is True
+ assert args.grad_reduce_in_fp32 is True
+ assert args.dataset_class == InMemoryPerTokenValueDataset
+ assert args.config_class == ESM2FineTuneTokenConfig
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py
index e601ce18ed..9575fafa9e 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py
@@ -19,45 +19,17 @@
import pandas as pd
import pytest
import torch
-from torch.utils.data import DataLoader
from bionemo.core.data.load import load
from bionemo.core.utils.dtypes import get_autocast_dtype
from bionemo.esm2.api import ESM2Config
from bionemo.esm2.data.tokenizer import get_tokenizer
-from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule, InMemoryCSVDataset
from bionemo.esm2.scripts.infer_esm2 import infer_model
from bionemo.llm.data import collate
from bionemo.llm.lightning import batch_collator
from bionemo.llm.utils.callbacks import IntervalT
-# Function to check GPU memory
-def check_gpu_memory(threshold_gb):
- if torch.cuda.is_available():
- gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Memory in GB
- return gpu_memory < threshold_gb
- return False
-
-
-@pytest.fixture
-def dummy_protein_sequences():
- """Create a list of artificial protein sequences"""
- artificial_sequence_data = [
- "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
- "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
- "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
- "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
- "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
- "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
- "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
- ]
- return artificial_sequence_data
-
-
@pytest.fixture
def dummy_protein_csv(tmp_path, dummy_protein_sequences):
"""Create a mock protein dataset."""
@@ -70,16 +42,6 @@ def dummy_protein_csv(tmp_path, dummy_protein_sequences):
return csv_file
-@pytest.fixture
-def dataset(dummy_protein_csv):
- return InMemoryCSVDataset(dummy_protein_csv)
-
-
-@pytest.fixture
-def data_module(dataset):
- return ESM2FineTuneDataModule(predict_dataset=dataset)
-
-
@pytest.fixture
def padded_tokenized_sequences(dummy_protein_sequences):
tokenizer = get_tokenizer()
@@ -91,49 +53,6 @@ def padded_tokenized_sequences(dummy_protein_sequences):
return collated_batch["text"]
-def test_in_memory_csv_dataset(dataset):
- assert len(dataset) > 0
- sample = dataset[0]
- assert isinstance(sample, dict)
- assert "text" in sample
- assert "labels" in sample
-
-
-def test_in_memory_csv_dataset_load_data(dataset, dummy_protein_csv):
- sequences, labels = dataset.load_data(dummy_protein_csv)
- assert isinstance(sequences, list)
- assert isinstance(labels, list)
-
-
-def test_esm2_fine_tune_data_module_init(data_module):
- assert data_module.train_dataset is None
- assert data_module.valid_dataset is None
- assert data_module.predict_dataset is not None
-
-
-def test_esm2_fine_tune_data_module_predict_dataloader(data_module):
- predict_dataloader = data_module.predict_dataloader()
- assert isinstance(predict_dataloader, DataLoader)
- batch = next(iter(predict_dataloader))
- assert isinstance(batch, dict)
- assert "text" in batch
-
-
-def test_esm2_fine_tune_data_module_setup(data_module):
- with pytest.raises(RuntimeError):
- data_module.setup("fit")
-
-
-def test_esm2_fine_tune_data_module_train_dataloader(data_module):
- with pytest.raises(AttributeError):
- data_module.train_dataloader()
-
-
-def test_esm2_fine_tune_data_module_val_dataloader(data_module):
- with pytest.raises(AttributeError):
- data_module.val_dataloader()
-
-
@pytest.mark.parametrize("precision", ["fp32", "bf16-mixed"])
@pytest.mark.parametrize("prediction_interval", get_args(IntervalT))
def test_infer_runs(