Fully Integrate SCDL into Geneformer #480

Merged: 32 commits, Dec 20, 2024
Changes from 25 commits

Commits (32)
08ef851
Change RowFeatureIndex and RowFeatureIndex tests to use a list of dic…
Nov 20, 2024
c9ff683
Update load_h5ad to append features in dict format to the row feat…
Nov 20, 2024
5f86da4
Modify Single Cell Memmap Dataset unit tests to reflect changes
Nov 20, 2024
896bad0
remove conversion to np.array in get_row for now
Nov 20, 2024
eb17845
Convert values and col indices to np array so that we're not returnin…
Nov 21, 2024
1663903
Revert conversion to np array, and refactor num_vars_at_row to use in…
Nov 22, 2024
0497c98
Merge branch 'main' into savitha/scdl-performance-improvements
savitha-eng Nov 22, 2024
7a43706
Made changes requested in review.
Nov 25, 2024
da395b4
Merge branch 'savitha/scdl-performance-improvements' of github.com:NV…
Nov 25, 2024
9e11ab8
Integrate SCDL into Geneformer, rebased on the latest changes in main
Nov 26, 2024
37de5d1
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
savitha-eng Nov 26, 2024
546f84e
Tests for Geneformer SingleCellDataset
Nov 26, 2024
0dcc56b
Merge branch 'savitha/integrate-scdl-geneformer-rebased' of github.co…
Nov 26, 2024
9d4c6a4
Data directory fixtures needed for pytest
Nov 26, 2024
eea6b42
Add bypass_tokenize_vocab to the arguments for this script
Nov 26, 2024
e642bc9
Changes to Inference tutorial notebook to support SCDL integrated Gen…
Dec 2, 2024
d755901
modify dataset dir creation
Dec 2, 2024
507e31b
all scdl integration changes
Dec 2, 2024
f6d9380
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
savitha-eng Dec 2, 2024
9846894
Updated documentation, removed refs to sc_memmap, & made changes requ…
Dec 4, 2024
6afed04
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
savitha-eng Dec 4, 2024
2e18cbb
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
savitha-eng Dec 5, 2024
dc093e2
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
polinabinder1 Dec 11, 2024
13ec59f
Make CLI argument for checking token vocab more understandable
Dec 13, 2024
1e9b58a
merge main
polinabinder1 Dec 13, 2024
452dc6c
merge main
polinabinder1 Dec 18, 2024
7bdddb0
adding fixed length
polinabinder1 Dec 18, 2024
71bddc3
notebook updates
polinabinder1 Dec 18, 2024
fdfc18c
adding correct notebook
polinabinder1 Dec 19, 2024
40e34cc
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
polinabinder1 Dec 19, 2024
1cbbf05
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
polinabinder1 Dec 20, 2024
696d42c
test case fixes
polinabinder1 Dec 20, 2024
1 change: 0 additions & 1 deletion .gitignore
@@ -2,7 +2,6 @@
docs/site/
*.nemo
protein/
singlecell/
results/

# Local configs
16 changes: 8 additions & 8 deletions README.md
@@ -279,10 +279,10 @@ type, and then pass in the config type to the training recipe.
Similar to ESM-2, you can download the dataset and checkpoint through our utility function.

```bash
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20240506 --source $MY_DATA_SOURCE); \
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20241203 --source $MY_DATA_SOURCE); \
GENEFORMER_10M_CKPT=$(download_bionemo_data geneformer/10M_240530:2.0 --source $MY_DATA_SOURCE); \
train_geneformer \
--data-dir ${TEST_DATA_DIR}/cellxgene_2023-12-15_small/processed_data \
--data-dir ${TEST_DATA_DIR}/cellxgene_2023-12-15_small_processed_scdl \
--result-dir ./results \
--restore-from-checkpoint-path ${GENEFORMER_10M_CKPT} \
--experiment-name test_experiment \
@@ -305,9 +305,9 @@ copy the `sub-projects/bionemo-geneformer/geneformer/scripts/train_geneformer.py
Simple fine-tuning example (**NOTE**: please change `--restore-from-checkpoint-path` to be the checkpoint directory path that was output last
by the previous train run)
```bash
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20240506 --source $MY_DATA_SOURCE); \
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20241203 --source $MY_DATA_SOURCE); \
train_geneformer \
--data-dir ${TEST_DATA_DIR}/cellxgene_2023-12-15_small/processed_data \
--data-dir ${TEST_DATA_DIR}/cellxgene_2023-12-15_small_processed_scdl \
--result-dir ./results \
--experiment-name test_finettune_experiment \
--num-gpus 1 \
@@ -331,11 +331,11 @@ customizations for your task.


```bash
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20240506 --source $MY_DATA_SOURCE); \
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20241203 --source $MY_DATA_SOURCE); \
bionemo-geneformer-recipe \
--recipe geneformer_10m_pretrain_recipe \
--dest my_config.yaml \
--data-path ${TEST_DATA_DIR}/cellxgene_2023-12-15_small/processed_data \
--recipe 10m-pretrain \
--dest my_config.json \
--data-path ${TEST_DATA_DIR}/cellxgene_2023-12-15_small_processed_scdl \
--result-dir ./results
```
> ⚠️ **IMPORTANT:** Inspect and edit the contents of the outputted my_config.yaml as you see fit
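A side note on the recipe workflow shown in the README hunk above: once `bionemo-geneformer-recipe` has written `my_config.yaml`, it can be inspected programmatically before editing and launching training. The snippet below is only an illustration; it assumes PyYAML is installed and that the generated file is plain YAML.

```python
# Illustrative only: peek at the generated recipe config before editing or launching training.
# Assumes PyYAML is installed and my_config.yaml is the file written by the recipe command above.
import yaml

with open("my_config.yaml") as f:
    config = yaml.safe_load(f)

# Print the top-level sections so you know what to review (data paths, parallelism, optimizer, ...).
print(list(config))
```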

Large diffs are not rendered by default.

@@ -5,3 +5,11 @@
sha256: 7a4237537bf535dfa00301ce8cc7073e0a23d5bc8aa902ad65db9f51b57a6df9 # pragma: allowlist secret
owner: Polina Binder <pbinder@nvidia.com>
description: Sample test data for SCDL.

- tag: sample_scdl_feature_ids
ngc: nvidia/clara/scdl_sample_test_feature_ids:1.0
ngc_registry: resource
pbss: s3://bionemo-ci/test-data/scdl_sample_test_feat_ids.tar.gz
sha256: 9020ba336dbfe33bddadba26ca0cde49958cbd73c5ad44f0960a5a4837c9db26 # pragma: allowlist secret
owner: Savitha Srinivasan <savithas@nvidia.com>
description: Sample test data for SCDL with feature IDs appended.
@@ -21,3 +21,11 @@
sha256: ab038b184de52e53ff7bcea5e01d97d55944c507db88c0495bdf9e5e9e0303a4 # pragma: allowlist secret
owner: John St John <jstjohn@nvidia.com>
description: Golden values for geneformer QA model.

- tag: testdata-20241203
ngc: nvidia/clara/singlecell-testdata:2.0
ngc_registry: resource
pbss: "s3://bionemo-ci/test-data/singlecell/singlecell-scdltestdata-20241203.tar.gz"
sha256: d8e3ea569bc43768c24aa651aff77722df202078415528497c22394046b08cc3 # pragma: allowlist secret
owner: Savitha Srinivasan <savithas@nvidia.com>
description: Test data for single cell models in SCDL Memmap format.
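Both new resource entries above pair a download location with a pinned sha256 digest. A generic way to check a downloaded tarball against the recorded value is sketched below; the local filename is only a placeholder for wherever the archive was saved.

```python
# Generic integrity check for a downloaded resource tarball; the filename is a placeholder.
import hashlib

EXPECTED_SHA256 = "d8e3ea569bc43768c24aa651aff77722df202078415528497c22394046b08cc3"  # testdata-20241203


def sha256sum(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file in chunks so large tarballs do not need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


assert sha256sum("singlecell-scdltestdata-20241203.tar.gz") == EXPECTED_SHA256
```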
2 changes: 1 addition & 1 deletion sub-packages/bionemo-geneformer/README.md
@@ -16,7 +16,7 @@ pytest -v .


## Acquiring Data
Datasets are expected to be in the form of AnnData (.h5ad) objects such as those downloaded from [Cell x Gene | CZI](https://chanzuckerberg.github.io/cellxgene-census/). They are then pre-processed with either `bionemo-geneformer/src/bionemo/geneformer/data/singlecell/sc_memmap.py` or with sc-DL.
Datasets are expected to be in the form of AnnData (.h5ad) objects such as those downloaded from [Cell x Gene | CZI](https://chanzuckerberg.github.io/cellxgene-census/). They are then pre-processed with `sub-packages/bionemo-scdl/src/bionemo/scdl/scripts/convert_h5ad_to_scdl.py`.

## Geneformer-nv 10M and 106M
Refer to the Dataset cards and Model cards to learn more about the pre-trained checkpoints provided for both 10M and 106M of Geneformer-nv.
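For readers of the updated "Acquiring Data" text above: besides the `convert_h5ad_to_scdl.py` script it references, the conversion can also be sketched in a few lines of Python. The import path, constructor signature, and `save()` call below are assumptions about the bionemo-scdl API and may not match the released package exactly; the script remains the supported path.

```python
# Hedged sketch: build an SCDL memmap dataset directly from one AnnData (.h5ad) file.
# The import path, constructor signature, and save() method are assumptions about the
# bionemo-scdl API; convert_h5ad_to_scdl.py (referenced above) is the supported route.
from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

# Reads my_cells.h5ad (placeholder path) and writes memmap-backed arrays under my_dataset_scdl/.
dataset = SingleCellMemMapDataset("my_dataset_scdl", h5ad_path="my_cells.h5ad")
dataset.save()  # flush metadata so training scripts can reopen the directory via --data-dir
```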
1 change: 0 additions & 1 deletion sub-packages/bionemo-geneformer/pyproject.toml
@@ -21,7 +21,6 @@ dependencies = [
[project.scripts]
bionemo-geneformer-train= "bionemo.geneformer.run.main:main"
bionemo-geneformer-recipe= "bionemo.geneformer.run.recipes:main"
sc_memmap = "bionemo.geneformer.scripts.sc_memmap:main_cli"
infer_geneformer = "bionemo.geneformer.scripts.infer_geneformer:geneformer_infer_entrypoint"
train_geneformer = "bionemo.geneformer.scripts.train_geneformer:entrypoint"
geneformer_mlm_loss_eval = "bionemo.geneformer.scripts.geneformer_mlm_loss_eval:entrypoint"
@@ -128,6 +128,7 @@ def main(
seq_len_nv: int = 2048,
seq_len_hf: int = 2048,
seed: int = 513,
include_unrecognized_vocab_in_dataset: bool = False,
):
"""Inference function (requires DDP and only training data that fits in memory)."""
# This is just used to get the tokenizer :(
@@ -185,6 +186,7 @@
max_len=seq_len_nv,
mask_prob=mask_prob,
seed=seed,
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
ds_hf_nvfilt = SingleCellDataset(
dataset_path,
@@ -194,6 +196,7 @@
mask_prob=mask_prob,
eos_token=hf_tokenizer.token_to_id(hf_tokenizer.sep_token), # Stored in the special token
seed=seed,
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
print(f"Loaded dataset of length (NV): {len(ds_nv)}, (HF): {len(ds_hf_nvfilt)}")

@@ -299,6 +302,11 @@ def entrypoint():
)
parser.add_argument("--hf-model-path", type=str, default="ctheodoris/Geneformer", help="HF model path")
parser.add_argument("--dataset-path", type=Path, help="Path to dataset directory", required=True)
parser.add_argument(
"--include-unrecognized-vocab-in-dataset",
action="store_true",
help="If set to true, a hard-check is performed to verify all gene identifiers are in the user-supplied tokenizer vocab. Defaults to false, which means any gene identifier not in the user-supplied tokenizer vocab will be excluded.",
)

args = parser.parse_args()
main(
@@ -307,6 +315,7 @@
args.dataset_path,
args.hf_token_dictionary_path,
args.hf_medians_dictionary_path,
args.include_unrecognized_vocab_in_dataset,
)


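The `--include-unrecognized-vocab-in-dataset` flag added above is documented as a hard check versus silent exclusion of gene identifiers missing from the tokenizer vocab. A minimal, self-contained sketch of that documented behavior (not the repository's implementation) looks like this:

```python
# Minimal sketch of the documented flag semantics; not the repository implementation.
from typing import Dict, List


def filter_gene_ids(
    gene_ids: List[str],
    vocab: Dict[str, int],
    include_unrecognized_vocab_in_dataset: bool = False,
) -> List[str]:
    unknown = [g for g in gene_ids if g not in vocab]
    if include_unrecognized_vocab_in_dataset and unknown:
        # Hard check: every identifier must already be in the user-supplied vocab.
        raise ValueError(f"{len(unknown)} gene identifiers missing from tokenizer vocab, e.g. {unknown[:3]}")
    # Default behavior: silently exclude identifiers the tokenizer does not recognize.
    return [g for g in gene_ids if g in vocab]


vocab = {"ENSG00000139618": 10, "ENSG00000141510": 11}
print(filter_gene_ids(["ENSG00000139618", "ENSG_UNKNOWN"], vocab))  # -> ['ENSG00000139618']
```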
@@ -51,6 +51,7 @@ class SingleCellDataModule(MegatronDataModule):
num_mask_per_sample (int): Number of masked versions of a single sample to be returned by each worker
train_batch_size (int): Batch size for training
val_batch_size (int): Batch size for validation
include_unrecognized_vocab_in_dataset (bool, optional): If set to True, a hard-check is performed to verify all gene identifiers are in the user-supplied tokenizer vocab. Defaults to False, which means any gene identifier not in the user-supplied tokenizer vocab will be excluded.

Attributes:
cfg (Config): Configuration object
@@ -82,6 +83,7 @@ def __init__( # noqa: D107
num_workers: int = 10, # TODO can this be automatically set?
persistent_workers: bool = True,
pin_memory: bool = True,
include_unrecognized_vocab_in_dataset: bool = False,
) -> None:
super().__init__()
if predict_dataset_path is None:
@@ -122,6 +124,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
self._val_dataset_ori = SingleCellDataset(
self.data_path_val,
@@ -132,6 +135,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
self._test_dataset_ori = SingleCellDataset(
self.data_path_test,
@@ -142,6 +146,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
self._predict_dataset_ori = None
else:
@@ -155,6 +160,7 @@
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
self._train_dataset_ori = None
self._val_dataset_ori = None