set min seq len by default (#621)
### Description
The report at https://nvbugspro.nvidia.com/bug/5060664 notes a warning
message about performance when pretraining with variable sequence
lengths. This was largely an oversight: our test scripts did not set
both the minimum and maximum sequence lengths. When `min_seq_length` is
omitted, the default should be to pad to the maximum sequence length
for performance reasons.
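
As a rough illustration of the effect (a minimal sketch, not the BioNeMo collate code; `pad_batch` and `PAD_ID` are hypothetical names introduced here), padding only to the longest sequence in each batch gives batch widths that change from step to step, while a minimum equal to the maximum gives a constant width:

```python
from typing import List, Optional

PAD_ID = 0  # hypothetical padding token id, for illustration only


def pad_batch(
    batch: List[List[int]],
    max_seq_length: int,
    min_seq_length: Optional[int] = None,
) -> List[List[int]]:
    """Truncate to max_seq_length, then pad to max(longest sequence, min_seq_length)."""
    truncated = [seq[:max_seq_length] for seq in batch]
    longest = max(len(seq) for seq in truncated)
    # None is treated as "no minimum" in this sketch (the pre-change behaviour).
    target = max(longest, min_seq_length or 0)
    return [seq + [PAD_ID] * (target - len(seq)) for seq in truncated]


batch_a = [[5, 6, 7], [8, 9]]
batch_b = [[5, 6, 7, 8, 9, 10], [11]]

# Without a minimum, the padded width follows the longest sequence in each batch.
widths_variable = [len(pad_batch(b, max_seq_length=8)[0]) for b in (batch_a, batch_b)]

# With min_seq_length == max_seq_length, every batch has the same width.
widths_fixed = [
    len(pad_batch(b, max_seq_length=8, min_seq_length=8)[0]) for b in (batch_a, batch_b)
]

print(widths_variable)  # [3, 6]
print(widths_fixed)  # [8, 8]
```

Fixed-width batches cost extra padding tokens, which is the trade-off this change accepts by default.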

### Type of changes
<!-- Mark the relevant option with an [x] -->

- [x]  Bug fix (non-breaking change which fixes an issue)
- [ ]  New feature (non-breaking change which adds functionality)
- [ ]  Refactor
- [ ]  Documentation update
- [ ]  Other (please describe):

### CI Pipeline Configuration
Configure CI behavior by applying the relevant labels:

- [SKIP_CI](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#skip_ci) - Skip all continuous integration tests
- [INCLUDE_NOTEBOOKS_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_notebooks_tests) - Execute notebook validation tests in pytest

> [!NOTE]
> By default, the notebook validation tests are skipped unless
> explicitly enabled.

### Usage
<!--- How does a user interact with the changed code -->
```python
# TODO: Add code snippet
```
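
A minimal sketch of the new default, assuming only the assignment shown in this commit's diff. `DataModuleSketch` is a hypothetical stand-in, not the real `ESMDataModule`, which takes additional required arguments such as the cluster and database paths:

```python
from typing import Optional


class DataModuleSketch:
    """Hypothetical stand-in that mirrors only the min/max length handling."""

    def __init__(self, min_seq_length: Optional[int] = None, max_seq_length: int = 1024) -> None:
        # Same expression as the changed line in datamodule.py.
        self._min_seq_length = min_seq_length if min_seq_length is not None else max_seq_length
        self._max_seq_length = max_seq_length


# Omitting min_seq_length now falls back to max_seq_length.
default = DataModuleSketch(max_seq_length=36)
# An explicit value is kept as given.
explicit = DataModuleSketch(min_seq_length=16, max_seq_length=36)

print(default._min_seq_length)  # 36
print(explicit._min_seq_length)  # 16
```

Because the check is `is not None`, any explicitly passed integer (including `0` in this sketch) is left untouched; only an omitted or `None` value triggers the fallback.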

### Pre-submit Checklist
<!--- Ensure all items are completed before submitting -->

 - [x] I have tested these changes locally
 - [x] I have updated the documentation accordingly
 - [x] I have added/updated tests as needed
 - [x] All existing tests pass successfully

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
pstjohn authored Jan 18, 2025
1 parent 7f9dd97 commit 0c990a7
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 3 additions & 3 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py
@@ -65,8 +65,8 @@ def __init__(
             valid_cluster_path: A path to the parquet files containing UniRef50 validation clusters.
             valid_database_path: A path to the sqlite file mapping UniRef50 cluster IDs to sequences.
             seed: Input random seed. If None, initializes randomly. Defaults to 42.
-            min_seq_length: Whether to pad sequences to a minimum length. If None, no extra padding is added. Defaults
-                to None.
+            min_seq_length: Whether to pad sequences to a minimum length. If None, sequences are padded to the maximum
+                sequence length. Defaults to None.
             max_seq_length: The maximum context length for the ESM transformer. Defaults to 1024.
             micro_batch_size: Passed to MegatronDataSampler. Defaults to 4.
             global_batch_size: Passed to MegatronDataSampler.. Defaults to 8.
@@ -87,7 +87,7 @@ def __init__(
         self._valid_cluster_path = valid_cluster_path
         self._valid_database_path = valid_database_path
         self._seed = seed
-        self._min_seq_length = min_seq_length
+        self._min_seq_length = min_seq_length if min_seq_length is not None else max_seq_length
         self._max_seq_length = max_seq_length
         self._mask_prob = mask_prob
         self._mask_token_prob = mask_token_prob
@@ -39,6 +39,21 @@ def test_create_esm_datamodule_raises_without_trainer(dummy_protein_dataset, dum
         data_module.setup()


+def test_esm_datamodule_sets_min_seq_len_to_max_seq_len(dummy_protein_dataset, dummy_parquet_train_val_inputs):
+    train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs
+
+    # Initialize the data module.
+    data_module = ESMDataModule(
+        train_cluster_path=train_cluster_path,
+        train_database_path=dummy_protein_dataset,
+        valid_cluster_path=valid_cluster_path,
+        valid_database_path=dummy_protein_dataset,
+        max_seq_length=36,
+    )
+
+    assert data_module._min_seq_length == 36
+
+
 def test_create_esm_datamodule_raises_without_trainer_max_steps(dummy_protein_dataset, dummy_parquet_train_val_inputs):
     train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs
