Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RowFeatureIndex Optimization #531

Merged
merged 13 commits into from
Dec 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@
# limitations under the License.

import glob
import os
import shlex
import subprocess
from pathlib import Path
from typing import get_args

import pandas as pd
import pytest
import torch
from lightning.fabric.plugins.environments.lightning import find_free_network_port
from torch.utils.data import DataLoader

from bionemo.core.data.load import load
Expand Down Expand Up @@ -147,7 +142,12 @@ def test_esm2_fine_tune_data_module_val_dataloader(data_module):
@pytest.mark.parametrize("prediction_interval", get_args(IntervalT))
@pytest.mark.skipif(check_gpu_memory(30), reason="Skipping test due to insufficient GPU memory")
def test_infer_runs(
tmpdir, dummy_protein_csv, dummy_protein_sequences, precision, padded_tokenized_sequences, prediction_interval
tmpdir,
dummy_protein_csv,
dummy_protein_sequences,
precision,
prediction_interval,
padded_tokenized_sequences,
):
data_path = dummy_protein_csv
result_dir = tmpdir / "results"
Expand Down Expand Up @@ -188,35 +188,9 @@ def test_infer_runs(
# token_logits are [sequence, batch, num_tokens]
assert results["token_logits"].shape[:-1] == (min_seq_len, len(dummy_protein_sequences))


@pytest.mark.skipif(check_gpu_memory(40), reason="Skipping test due to insufficient GPU memory")
@pytest.mark.parametrize("checkpoint_path", [esm2_3b_checkpoint_path, esm2_650m_checkpoint_path])
def test_infer_cli(tmpdir, dummy_protein_csv, checkpoint_path):
# Clear the GPU cache before starting the test
torch.cuda.empty_cache()

result_dir = Path(tmpdir.mkdir("results"))
results_path = result_dir / "esm2_infer_results.pt"
open_port = find_free_network_port()
env = dict(**os.environ)
env["MASTER_PORT"] = str(open_port)

cmd_str = f"""infer_esm2 \
--checkpoint-path {checkpoint_path} \
--data-path {dummy_protein_csv} \
--results-path {results_path} \
--precision bf16-mixed \
--include-hiddens \
--include-embeddings \
--include-logits \
--include-input-ids
""".strip()

cmd = shlex.split(cmd_str)
result = subprocess.run(
cmd,
cwd=tmpdir,
env=env,
capture_output=True,
)
assert result.returncode == 0, f"Failed with: {cmd_str}"
# test 1:1 mapping between input sequence and results
# this does not apply to "batch" prediction_interval mode since the order of batches may not be consistent
# due to distributed processing. To address this, we optionally include input_ids in the predictions, allowing
# for accurate mapping post-inference.
if prediction_interval == "epoch":
assert torch.equal(padded_tokenized_sequences, results["input_ids"])
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@
__all__: Sequence[str] = ("RowFeatureIndex",)


def are_dicts_equal(dict1: dict[str, np.ndarray], dict2: dict[str, np.ndarray]) -> bool:
    """Check whether two string-keyed dictionaries of numpy arrays match.

    Args:
        dict1 (dict[str, np.ndarray]): Left-hand dictionary.
        dict2 (dict[str, np.ndarray]): Right-hand dictionary.

    Returns:
        bool: True when both dictionaries share exactly the same key set and
        every pair of arrays stored under a common key is equal according to
        ``np.array_equal``; False otherwise.
    """
    # Differing key sets can never match, so bail out before touching any array.
    if dict1.keys() != dict2.keys():
        return False
    for key, left in dict1.items():
        if not np.array_equal(left, dict2[key]):
            return False
    return True


class RowFeatureIndex:
"""Maintains a mapping between a row and its features.

Expand Down Expand Up @@ -100,10 +114,16 @@ def append_features(
if isinstance(features, pd.DataFrame):
raise TypeError("Expected a dictionary, but received a Pandas DataFrame.")
csum = max(self._cumulative_sum_index[-1], 0)
self._cumulative_sum_index = np.append(self._cumulative_sum_index, csum + n_obs)
self._feature_arr.append(features)
self._num_genes_per_row.append(num_genes)
self._labels.append(label)

# If the new feature array is identical to the last one, it is not appended. Instead, the last array accounts
# for the additional n_obs also.
if len(self._feature_arr) > 0 and are_dicts_equal(self._feature_arr[-1], features):
self._cumulative_sum_index[-1] = csum + n_obs
else:
self._cumulative_sum_index = np.append(self._cumulative_sum_index, csum + n_obs)
self._feature_arr.append(features)
self._num_genes_per_row.append(num_genes)
self._labels.append(label)

def lookup(self, row: int, select_features: Optional[list[str]] = None) -> Tuple[list[np.ndarray], str]:
"""Find the features at a given row.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,32 @@
import pandas as pd
import pytest

from bionemo.scdl.index.row_feature_index import RowFeatureIndex
from bionemo.scdl.index.row_feature_index import RowFeatureIndex, are_dicts_equal


def test_equal_dicts():
    """Dictionaries with the same keys and element-wise equal arrays compare equal."""
    left = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
    right = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
    outcome = are_dicts_equal(left, right)
    assert outcome is True


def test_unequal_values():
    """Dictionaries with the same keys but a differing array compare unequal."""
    reference = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
    altered = {"a": np.array([1, 2, 3]), "b": np.array([7, 8, 9])}
    outcome = are_dicts_equal(reference, altered)
    assert outcome is False


def test_unequal_keys():
    """Dictionaries whose key sets differ compare unequal even when the arrays match."""
    reference = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
    renamed = {"a": np.array([1, 2, 3]), "c": np.array([4, 5, 6])}
    outcome = are_dicts_equal(reference, renamed)
    assert outcome is False


def test_different_lengths():
    """A dictionary compares unequal to one holding only a subset of its keys."""
    full = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
    subset = {"a": np.array([1, 2, 3])}
    outcome = are_dicts_equal(full, subset)
    assert outcome is False


@pytest.fixture
Expand All @@ -37,6 +62,20 @@ def create_first_RowFeatureIndex() -> RowFeatureIndex:
return index


@pytest.fixture
def create_same_features_first_RowFeatureIndex() -> RowFeatureIndex:
    """Build a RowFeatureIndex populated with one block of known features.

    Returns:
        A RowFeatureIndex on which ``append_features`` has been called once,
        with 6 as the first argument and a two-column feature dictionary
        (``feature_name`` and ``feature_int``) of three entries each.
    """
    feature_names = np.array(["FF", "GG", "HH"])
    known_features = {"feature_name": feature_names, "feature_int": np.array([1, 2, 3])}
    row_index = RowFeatureIndex()
    row_index.append_features(6, known_features, len(feature_names))
    return row_index


@pytest.fixture
def create_second_RowFeatureIndex() -> RowFeatureIndex:
"""
Expand Down Expand Up @@ -86,14 +125,17 @@ def test_feature_index_internals_on_single_index(create_first_RowFeatureIndex):
assert len(vals) == 1


def test_feature_index_internals_on_append(create_first_RowFeatureIndex):
def test_feature_index_internals_on_append_different_features(
create_first_RowFeatureIndex, create_second_RowFeatureIndex
):
one_feats = {"feature_name": np.array(["FF", "GG", "HH"]), "feature_int": np.array([1, 2, 3])}
two_feats = {
"feature_name": np.array(["FF", "GG", "HH", "II", "ZZ"]),
"gene_name": np.array(["RET", "NTRK", "PPARG", "TSHR", "EGFR"]),
"spare": np.array([None, None, None, None, None]),
}
create_first_RowFeatureIndex.append_features(8, two_feats, len(two_feats["feature_name"]), "MY_DATAFRAME")
create_first_RowFeatureIndex.concat(create_second_RowFeatureIndex)
# append(8, two_feats, len(two_feats["feature_name"]), "MY_DATAFRAME")
assert len(create_first_RowFeatureIndex) == 2
assert create_first_RowFeatureIndex.number_vars_at_row(1) == 3
assert create_first_RowFeatureIndex.number_vars_at_row(13) == 5
Expand All @@ -113,6 +155,28 @@ def test_feature_index_internals_on_append(create_first_RowFeatureIndex):
assert label == "MY_DATAFRAME"


def test_feature_index_internals_on_append_same_features(create_first_RowFeatureIndex):
    """Concatenating an index with itself must extend the existing entry, not add a new one.

    The fixture's features are identical on both sides of the ``concat``, so
    the index is expected to keep a single feature entry whose row span is
    doubled (24 rows, each reporting 3 variables).
    """
    one_feats = {"feature_name": np.array(["FF", "GG", "HH"]), "feature_int": np.array([1, 2, 3])}
    create_first_RowFeatureIndex.concat(create_first_RowFeatureIndex)

    # Still exactly one feature entry after the concat.
    assert len(create_first_RowFeatureIndex) == 1

    # Every probed row, on either side of the original boundary, reports the same three variables.
    assert create_first_RowFeatureIndex.number_vars_at_row(1) == 3
    assert create_first_RowFeatureIndex.number_vars_at_row(13) == 3
    assert create_first_RowFeatureIndex.number_vars_at_row(19) == 3
    assert create_first_RowFeatureIndex.number_vars_at_row(2) == 3

    # The single entry accounts for all values and all rows.
    assert sum(create_first_RowFeatureIndex.number_of_values()) == 2 * (12 * 3)
    assert create_first_RowFeatureIndex.number_of_values()[0] == 2 * (12 * 3)
    assert create_first_RowFeatureIndex.number_of_rows() == 24

    # Lookups in the first and the second half both resolve to the same features, with no label.
    feats, label = create_first_RowFeatureIndex.lookup(row=3, select_features=None)
    assert np.all(feats[0] == one_feats["feature_name"])
    assert np.all(feats[1] == one_feats["feature_int"])
    assert label is None
    feats, label = create_first_RowFeatureIndex.lookup(row=15, select_features=None)
    assert np.all(feats[0] == one_feats["feature_name"])
    assert np.all(feats[1] == one_feats["feature_int"])
    assert label is None


def test_concat_length(
create_first_RowFeatureIndex,
create_second_RowFeatureIndex,
Expand Down
Loading