don't eagerly download esm2 checkpoints (NVIDIA#567)
Adding the `load` calls at module scope causes these checkpoints to be
downloaded even when they are not used. We should call `load` as lazily
as possible to avoid materializing data for tests that don't need it.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
pstjohn authored Jan 3, 2025
1 parent d901ebd commit f12a475
Showing 1 changed file with 1 addition and 6 deletions.
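
To make the failure mode concrete, here is a minimal sketch of the eager-vs-lazy pattern (the import path and test body are illustrative assumptions; only the load("esm2/650m:2.0") call is taken from the diff below):

from bionemo.core.data.load import load  # assumed import path for the loader used in the diff

# Eager (before): executed at import time, so merely collecting the tests
# in this module downloads the checkpoint, even if every test that needs
# it is skipped or deselected.
# esm2_650m_checkpoint_path = load("esm2/650m:2.0")

# Lazy (after): the download happens only if this test actually runs.
def test_infer_runs():
    checkpoint_path = load("esm2/650m:2.0")
    # Assuming load() returns a local pathlib.Path, as its use in the diff suggests.
    assert checkpoint_path.exists()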
@@ -32,10 +32,6 @@
 from bionemo.llm.utils.callbacks import IntervalT
 
 
-esm2_650m_checkpoint_path = load("esm2/650m:2.0")
-esm2_3b_checkpoint_path = load("esm2/3b:2.0", source="ngc")
-
-
 # Function to check GPU memory
 def check_gpu_memory(threshold_gb):
     if torch.cuda.is_available():
@@ -140,7 +136,6 @@ def test_esm2_fine_tune_data_module_val_dataloader(data_module):
 
 @pytest.mark.parametrize("precision", ["fp32", "bf16-mixed"])
 @pytest.mark.parametrize("prediction_interval", get_args(IntervalT))
-@pytest.mark.skipif(check_gpu_memory(30), reason="Skipping test due to insufficient GPU memory")
 def test_infer_runs(
     tmpdir,
     dummy_protein_csv,
@@ -155,7 +150,7 @@ def test_infer_runs(
 
     infer_model(
         data_path=data_path,
-        checkpoint_path=esm2_650m_checkpoint_path,
+        checkpoint_path=load("esm2/650m:2.0"),
        results_path=result_dir,
         min_seq_length=min_seq_len,
         prediction_interval=prediction_interval,
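
If several tests in the module need the same checkpoint, a session-scoped pytest fixture is one way (a sketch, not part of this commit) to keep the call lazy while sharing a single download across tests:

import pytest

from bionemo.core.data.load import load  # assumed import path


@pytest.fixture(scope="session")
def esm2_650m_checkpoint_path():
    # Not evaluated until the first test that requests this fixture runs;
    # pytest then caches the result for the rest of the session.
    return load("esm2/650m:2.0")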
