diff --git a/README.md b/README.md index 018eee517..1299e5f5b 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,23 @@ graphium-train --config-path [PATH] --config-name [CONFIG] ``` Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium. +## Preparing the data in advance +The data preparation including the featurization (e.g., of molecules from smiles to pyg-compatible format) is embedded in the pipeline and will be performed when executing `graphium-train [...]`. + +However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). Preparing data in advance is also beneficial when running lots of concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict reading/writing in the same directory. + +The following command-line will prepare the data and cache it, then use it to train a model. +```bash +# First prepare the data and cache it in `path_to_cached_data` +graphium-prepare-data datamodule.args.processed_graph_data_path=[path_to_cached_data] + +# Then train the model on the prepared data +graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data] +``` + +**Note** that `datamodule.args.processed_graph_data_path` can also be specified at `expts/hydra-configs/`. + +**Note** that, every time the configs of `datamodule.args.featurization` change, you will need to run a new data preparation, which will automatically be saved in a separate directory that uses a hash unique to the configs. ## First Time Running on IPUs For new IPU developers this section helps provide some more explanation on how to set up an environment to use Graphcore IPUs with Graphium. 
diff --git a/docs/tutorials/feature_processing/choosing_parallelization.ipynb b/docs/tutorials/feature_processing/choosing_parallelization.ipynb index 1ebb54451..0ab569d57 100644 --- a/docs/tutorials/feature_processing/choosing_parallelization.ipynb +++ b/docs/tutorials/feature_processing/choosing_parallelization.ipynb @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "b5df2ac6-2ded-4597-a445-f2b5fb106330", "metadata": { "tags": [] @@ -24,8 +24,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: Pandarallel will run on 240 workers.\n", - "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n" + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" ] } ], @@ -39,9 +39,9 @@ "import datamol as dm\n", "import pandas as pd\n", "\n", - "from pandarallel import pandarallel\n", + "# from pandarallel import pandarallel\n", "\n", - "pandarallel.initialize(progress_bar=True, nb_workers=joblib.cpu_count())" + "# pandarallel.initialize(progress_bar=True, nb_workers=joblib.cpu_count())" ] }, { @@ -54,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "0f31e18d-bdd9-4d9b-8ba5-81e5887b857e", "metadata": { "tags": [] @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "id": "a1197c31-7dbc-4fd7-a69a-5215e1a96b8e", "metadata": { "tags": [] @@ -109,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "id": "2f8ce5c3-4232-4279-8ea3-7a74832303be", "metadata": { "tags": [] @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "id": "a246cdcf-b5ea-4c9e-9ccc-dd3c544587bb", "metadata": { "tags": [] @@ -138,7 +138,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3e939cd3a24742038b804bbfd961377d", + "model_id": "cc396220c7144c8d8b195fb87694bbfe", 
"version_major": 2, "version_minor": 0 }, @@ -489,7 +489,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/expts/configs/config_gps_10M_pcqm4m.yaml b/expts/configs/config_gps_10M_pcqm4m.yaml index 0b0dff7dc..10faa3b1e 100644 --- a/expts/configs/config_gps_10M_pcqm4m.yaml +++ b/expts/configs/config_gps_10M_pcqm4m.yaml @@ -112,7 +112,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/configs/config_gps_10M_pcqm4m_mod.yaml b/expts/configs/config_gps_10M_pcqm4m_mod.yaml index 1c9f6da31..e2cdb44c2 100644 --- a/expts/configs/config_gps_10M_pcqm4m_mod.yaml +++ b/expts/configs/config_gps_10M_pcqm4m_mod.yaml @@ -81,7 +81,6 @@ datamodule: # Data handling-related batch_size_training: 64 batch_size_inference: 16 - # cache_data_path: . num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/configs/config_mpnn_10M_b3lyp.yaml b/expts/configs/config_mpnn_10M_b3lyp.yaml index dca4cd540..c385d7689 100644 --- a/expts/configs/config_mpnn_10M_b3lyp.yaml +++ b/expts/configs/config_mpnn_10M_b3lyp.yaml @@ -93,6 +93,7 @@ datamodule: featurization_progress: True featurization_backend: "loky" processed_graph_data_path: "../datacache/b3lyp/" + dataloading_from: ram featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -123,7 +124,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . 
num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/configs/config_mpnn_pcqm4m.yaml b/expts/configs/config_mpnn_pcqm4m.yaml index 4e34f89ea..9735f9555 100644 --- a/expts/configs/config_mpnn_pcqm4m.yaml +++ b/expts/configs/config_mpnn_pcqm4m.yaml @@ -30,8 +30,8 @@ datamodule: featurization_n_jobs: 20 featurization_progress: True featurization_backend: "loky" - cache_data_path: "./datacache" processed_graph_data_path: "graphium/data/PCQM4Mv2/" + dataloading_from: ram featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -58,7 +58,6 @@ datamodule: # Data handling-related batch_size_training: 64 batch_size_inference: 16 - # cache_data_path: . num_workers: 40 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/hydra-configs/architecture/toymix.yaml b/expts/hydra-configs/architecture/toymix.yaml index 6927f4e66..c79325919 100644 --- a/expts/hydra-configs/architecture/toymix.yaml +++ b/expts/hydra-configs/architecture/toymix.yaml @@ -79,6 +79,7 @@ datamodule: featurization_progress: True featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" + dataloading_from: ram num_workers: 30 # -1 to use all persistent_workers: False featurization: diff --git a/expts/neurips2023_configs/base_config/large.yaml b/expts/neurips2023_configs/base_config/large.yaml index db2b5dbb6..5ba023b3e 100644 --- a/expts/neurips2023_configs/base_config/large.yaml +++ b/expts/neurips2023_configs/base_config/large.yaml @@ -168,7 +168,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . 
num_workers: 32 # -1 to use all persistent_workers: True # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/base_config/small.yaml b/expts/neurips2023_configs/base_config/small.yaml index 2e63477a1..fd7ce3fbe 100644 --- a/expts/neurips2023_configs/base_config/small.yaml +++ b/expts/neurips2023_configs/base_config/small.yaml @@ -132,7 +132,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml index 401dcabd6..7b2d2cbdf 100644 --- a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml @@ -131,7 +131,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/config_classifigression_l1000.yaml b/expts/neurips2023_configs/config_classifigression_l1000.yaml index 37d83736f..48f06d9d1 100644 --- a/expts/neurips2023_configs/config_classifigression_l1000.yaml +++ b/expts/neurips2023_configs/config_classifigression_l1000.yaml @@ -111,7 +111,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 5 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
diff --git a/expts/neurips2023_configs/config_luis_jama.yaml b/expts/neurips2023_configs/config_luis_jama.yaml index 46ec4c4c0..5135c5cae 100644 --- a/expts/neurips2023_configs/config_luis_jama.yaml +++ b/expts/neurips2023_configs/config_luis_jama.yaml @@ -119,7 +119,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 4 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/debug/config_debug.yaml b/expts/neurips2023_configs/debug/config_debug.yaml index a323427e5..3d31e5e8c 100644 --- a/expts/neurips2023_configs/debug/config_debug.yaml +++ b/expts/neurips2023_configs/debug/config_debug.yaml @@ -105,7 +105,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml index 5fe6d8741..ec05bf6eb 100644 --- a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml @@ -166,7 +166,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
@@ -327,7 +326,7 @@ predictor: l1000_mcf7: [] pcba_1328: [] pcqm4m_g25: [] - pcqm4m_n4: [] + pcqm4m_n4: [] loss_fun: l1000_vcap: name: hybrid_ce_ipu diff --git a/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml index 717ae0675..26b50756f 100644 --- a/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml @@ -119,7 +119,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml index f73d4b08c..e05d1be8d 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml index 3985f26a7..cf924850e 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml index dad3893a9..f1c9bcfd4 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml index 20c6aaa37..01988e527 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml @@ -103,7 +103,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml index 12974d9e4..fdeb4b399 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml index 72320f137..5920a80f6 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml @@ -104,7 +104,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml index 1d9601ee1..de2f7fbc4 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml index 85fce8e13..ca820e86b 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml @@ -118,7 +118,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml index c52a041f1..c21b765b3 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml index 4ab892b00..b88314797 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml @@ -103,7 +103,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml index 8605121f1..b96fc8daf 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml index ee89b6012..e98ae03da 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml @@ -104,7 +104,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml index 42ac474e9..427f7ca0f 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml index 84bd5c66b..07fc6d009 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml @@ -118,7 +118,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml index 09d23fb92..b63263b3d 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/graphium/cli/prepare_data.py b/graphium/cli/prepare_data.py new file mode 100644 index 000000000..7a8c6eceb --- /dev/null +++ b/graphium/cli/prepare_data.py @@ -0,0 +1,42 @@ +import hydra +import timeit + +from omegaconf import DictConfig, OmegaConf +from loguru import logger + +from graphium.config._loader import load_datamodule, load_accelerator + + +@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main") +def cli(cfg: DictConfig) -> None: + """ + CLI endpoint for preparing the data in advance. + """ + run_prepare_data(cfg) + + +def run_prepare_data(cfg: DictConfig) -> None: + """ + The main (pre-)training and fine-tuning loop. + """ + + cfg = OmegaConf.to_container(cfg, resolve=True) + st = timeit.default_timer() + + # Checking that `processed_graph_data_path` is provided + path = cfg["datamodule"]["args"].get("processed_graph_data_path", None) + if path is None: + raise ValueError( + "Please provide `datamodule.args.processed_graph_data_path` to specify the caching dir." 
+ ) + logger.info(f"The caching dir is set to '{path}'") + + # Data-module + datamodule = load_datamodule(cfg, "cpu") + datamodule.prepare_data() + + logger.info(f"Data preparation took {timeit.default_timer() - st:.2f} seconds.") + + +if __name__ == "__main__": + cli() diff --git a/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml b/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml index 48d55a501..044a0129c 100644 --- a/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml +++ b/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml @@ -58,7 +58,6 @@ datamodule: # Data handling-related batch_size_training: 16 batch_size_inference: 16 - # cache_data_path: null architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml` model_type: FullGraphMultiTaskNetwork @@ -111,7 +110,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none graph: pooling: [sum, max] out_dim: 1 @@ -122,7 +121,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none edge: out_dim: 16 hidden_dims: 32 @@ -132,7 +131,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none nodepair: out_dim: 16 hidden_dims: 32 @@ -142,7 +141,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none task_heads: # Set as null to avoid task heads. 
Recall that the arguments for the TaskHeads is a List of TaskHeadParams task_1: diff --git a/graphium/config/fake_multilevel_multitask_pyg.yaml b/graphium/config/fake_multilevel_multitask_pyg.yaml index 3ca5085f9..918807cb4 100644 --- a/graphium/config/fake_multilevel_multitask_pyg.yaml +++ b/graphium/config/fake_multilevel_multitask_pyg.yaml @@ -58,7 +58,6 @@ datamodule: # Data handling-related batch_size_training: 16 batch_size_inference: 16 - # cache_data_path: null architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml` model_type: FullGraphMultiTaskNetwork @@ -111,7 +110,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none graph: pooling: [sum, max] out_dim: 1 @@ -122,7 +121,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none edge: out_dim: 16 hidden_dims: 32 @@ -132,7 +131,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none nodepair: out_dim: 16 hidden_dims: 32 @@ -142,7 +141,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none task_heads: # Set as null to avoid task heads. 
Recall that the arguments for the TaskHeads is a List of TaskHeadParams task_1: diff --git a/graphium/config/zinc_default_multitask_pyg.yaml b/graphium/config/zinc_default_multitask_pyg.yaml index 192d2c4ef..07ae4bf9b 100644 --- a/graphium/config/zinc_default_multitask_pyg.yaml +++ b/graphium/config/zinc_default_multitask_pyg.yaml @@ -58,7 +58,6 @@ datamodule: # Data handling-related batch_size_training: 16 batch_size_inference: 16 - # cache_data_path: null architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml` model_type: FullGraphMultiTaskNetwork diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index b85fd664c..e8cab271d 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1,6 +1,7 @@ import tempfile from contextlib import redirect_stderr, redirect_stdout from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable, Literal +from os import PathLike as Path from dataclasses import dataclass @@ -135,6 +136,7 @@ def __init__( self._predict_ds = None self._data_is_prepared = False + self._data_is_cached = False def prepare_data(self): raise NotImplementedError() @@ -770,8 +772,8 @@ class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier): def __init__( self, task_specific_args: Union[DatasetProcessingParams, Dict[str, Any]], - cache_data_path: Optional[Union[str, os.PathLike]] = None, processed_graph_data_path: Optional[Union[str, os.PathLike]] = None, + dataloading_from: str = "ram", featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, @@ -835,8 +837,11 @@ def __init__( task_splits_path: (value) A path a CSV file containing indices for the splits. The file must contains 3 columns "train", "val" and "test". It takes precedence over `split_val` and `split_test`. - cache_data_path: path where to save or reload the cached data. 
The path can be - remote (S3, GS, etc). + processed_graph_data_path: path where to save or reload the cached data. Can be used + to avoid recomputing the featurization, or for dataloading from disk with the option `dataloader_from="disk"`. + dataloading_from: Whether to load the data from RAM or from disk. If set to "disk", the data + must have been previously cached with `processed_graph_data_path` set. If set to "ram", the data + will be loaded in RAM and the `processed_graph_data_path` will be ignored. featurization: args to apply to the SMILES to Graph featurizer. batch_size_training: batch size for training and val dataset. batch_size_inference: batch size for test dataset. @@ -909,10 +914,7 @@ def __init__( self.val_ds = None self.test_ds = None - self.cache_data_path = cache_data_path - self.processed_graph_data_path = processed_graph_data_path - - self.load_from_file = processed_graph_data_path is not None + self._parse_caching_args(processed_graph_data_path, dataloading_from) self.task_norms = {} @@ -932,6 +934,32 @@ def __init__( ) self.data_hash = self.get_data_hash() + if self.processed_graph_data_path is not None: + if self._ready_to_load_all_from_file(): + self._data_is_prepared = True + self._data_is_cached = True + + def _parse_caching_args(self, processed_graph_data_path, dataloading_from): + """ + Parse the caching arguments, and raise errors if the arguments are invalid. + """ + + # Whether to load the data from RAM or from disk + dataloading_from = dataloading_from.lower() + if dataloading_from not in ["disk", "ram"]: + raise ValueError( + f"`dataloading_from` should be either 'disk' or 'ram', Provided: `{dataloading_from}`" + ) + + # If loading from disk, the path to the cached data must be provided + if dataloading_from == "disk" and processed_graph_data_path is None: + raise ValueError( + "When `dataloading_from` is 'disk', `processed_graph_data_path` must be provided." 
+ ) + + self.processed_graph_data_path = processed_graph_data_path + self.dataloading_from = dataloading_from + def _get_task_key(self, task_level: str, task: str): task_prefix = f"{task_level}_" if not task.startswith(task_prefix): @@ -948,7 +976,7 @@ def get_task_levels(self): return task_level_map - def prepare_data(self): + def prepare_data(self, save_smiles_and_ids: bool = False): """Called only from a single process in distributed settings. Steps: - If each cache is set and exists, reload from cache and return. Otherwise, @@ -973,26 +1001,10 @@ def has_atoms_after_h_removal(smiles): return has_atoms if self._data_is_prepared: - logger.info("Data is already prepared. Skipping the preparation") + logger.info("Data is already prepared.") + self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None) return - if self.load_from_file: - if self._ready_to_load_all_from_file(): - self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None) - self._data_is_prepared = True - return - - else: - # If a path for data caching is provided, try to load from the path. - # If successful, skip the data preparation. 
- # For next task: load the single graph files for train, val and test data - cache_data_exists = self.load_data_from_cache() - # need to check if cache exist properly - if cache_data_exists: - self.get_label_statistics(self.cache_data_path, self.data_hash, dataset=None) - self._data_is_prepared = True - return - """Load all single-task dataframes.""" task_df = {} for task, args in self.task_dataset_processing_params.items(): @@ -1160,12 +1172,9 @@ def has_atoms_after_h_removal(smiles): self.single_task_datasets, self.task_train_indices, self.task_val_indices, self.task_test_indices ) - if self.load_from_file: - self._save_data_to_files() - - # When a cache path is provided but no cache is found, save to cache - elif (self.cache_data_path is not None) and (not cache_data_exists): - self.save_data_to_cache() + if self.processed_graph_data_path is not None: + self._save_data_to_files(save_smiles_and_ids) + self._data_is_cached = True self._data_is_prepared = True @@ -1185,21 +1194,15 @@ def setup( labels_size = {} labels_dtype = {} if stage == "fit" or stage is None: - if self.load_from_file: - processed_train_data_path = self._path_to_load_from_file("train") - assert self._data_ready_at_path( - processed_train_data_path - ), "Loading from file + setup() called but training data not ready" - processed_val_data_path = self._path_to_load_from_file("val") - assert self._data_ready_at_path( - processed_val_data_path - ), "Loading from file + setup() called but validation data not ready" - else: - processed_train_data_path = None - processed_val_data_path = None + # if self.train_ds is None: + self.train_ds = self._make_multitask_dataset( + self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids + ) + # if self.val_ds is None: + self.val_ds = self._make_multitask_dataset( + self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids + ) - self.train_ds = self._make_multitask_dataset("train", save_smiles_and_ids=save_smiles_and_ids) - self.val_ds 
= self._make_multitask_dataset("val", save_smiles_and_ids=save_smiles_and_ids) logger.info(self.train_ds) logger.info(self.val_ds) labels_size.update( @@ -1210,14 +1213,11 @@ def setup( labels_dtype.update(self.val_ds.labels_dtype) if stage == "test" or stage is None: - if self.load_from_file: - processed_test_data_path = self._path_to_load_from_file("test") - assert self._data_ready_at_path( - processed_test_data_path - ), "Loading from file + setup() called but test data not ready" - else: - processed_test_data_path = None - self.test_ds = self._make_multitask_dataset("test", save_smiles_and_ids=save_smiles_and_ids) + # if self.test_ds is None: + self.test_ds = self._make_multitask_dataset( + self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids + ) + logger.info(self.test_ds) labels_size.update(self.test_ds.labels_size) @@ -1235,9 +1235,9 @@ def setup( def _make_multitask_dataset( self, + dataloading_from: Literal["disk", "ram"], stage: Literal["train", "val", "test"], save_smiles_and_ids: bool, - load_from_file: Optional[bool] = None, ) -> Datasets.MultitaskDataset: """ Create a MultitaskDataset for the given stage using single task datasets @@ -1246,8 +1246,7 @@ def _make_multitask_dataset( Parameters: stage: Stage to create multitask dataset for save_smiles_and_ids: Whether to save SMILES strings and unique IDs - data_path: path to load from if loading from file - load_from_file: whether to load from file. 
If `None`, defers to `self.load_from_file` + processed_graph_data_path: path to save and load processed graph data from """ allowed_stages = ["train", "val", "test"] @@ -1265,18 +1264,7 @@ def _make_multitask_dataset( else: raise ValueError(f"Unknown stage {stage}") - if load_from_file is None: - load_from_file = self.load_from_file - - # assert singletask_datasets is not None, "Single task datasets must exist to make multitask dataset" - if singletask_datasets is None: - assert load_from_file - assert self._data_ready_at_path( - self._path_to_load_from_file(stage) - ), "Trying to create multitask dataset without single-task datasets but data not ready" - files_ready = True - else: - files_ready = False + processed_graph_data_path = self.processed_graph_data_path multitask_dataset = Datasets.MultitaskDataset( singletask_datasets, @@ -1286,9 +1274,9 @@ def _make_multitask_dataset( progress=self.featurization_progress, about=about, save_smiles_and_ids=save_smiles_and_ids, - data_path=self._path_to_load_from_file(stage) if load_from_file else None, - load_from_file=load_from_file, - files_ready=files_ready, + data_path=self._path_to_load_from_file(stage) if processed_graph_data_path else None, + dataloading_from=dataloading_from, + data_is_cached=self._data_is_cached, ) # type: ignore # calculate statistics for the train split and used for all splits normalization @@ -1296,7 +1284,8 @@ def _make_multitask_dataset( self.get_label_statistics( self.processed_graph_data_path, self.data_hash, multitask_dataset, train=True ) - if not load_from_file: + # Normalization has already been applied in cached data + if not self._data_is_prepared: self.normalize_label(multitask_dataset, stage) return multitask_dataset @@ -1327,7 +1316,7 @@ def _data_ready_at_path(self, path: str) -> bool: return can_load_from_file - def _save_data_to_files(self) -> None: + def _save_data_to_files(self, save_smiles_and_ids: bool = False) -> None: """ Save data to files so that they can be loaded from 
file during training/validation/test """ @@ -1337,7 +1326,9 @@ def _save_data_to_files(self) -> None: # At the moment, we need to merge the `SingleTaskDataset`'s into `MultitaskDataset`s in order to save to file # This is because the combined labels need to be stored together. We can investigate not doing this if this is a problem temp_datasets = { - stage: self._make_multitask_dataset(stage, save_smiles_and_ids=False, load_from_file=False) + stage: self._make_multitask_dataset( + dataloading_from="ram", stage=stage, save_smiles_and_ids=save_smiles_and_ids + ) for stage in stages } for stage in stages: @@ -1359,6 +1350,7 @@ def calculate_statistics(self, dataset: Datasets.MultitaskDataset, train: bool = train: whether the dataset is the training set """ + if self.task_norms and train: for task in dataset.labels_size.keys(): # if the label type is graph_*, we need to stack them as the tensor shape is (num_labels, ) @@ -2004,60 +1996,18 @@ def get_data_cache_fullname(self, compress: bool = False) -> str: Returns: full path to the data cache file """ - if self.cache_data_path is None: + if self.processed_graph_data_path is None: return ext = ".datacache" if compress: ext += ".gz" - data_cache_fullname = fs.join(self.cache_data_path, self.data_hash + ext) + data_cache_fullname = fs.join(self.processed_graph_data_path, self.data_hash + ext) return data_cache_fullname - def save_data_to_cache(self, verbose: bool = True, compress: bool = False) -> None: - """ - Save the datasets from cache. First create a hash for the dataset, use it to - generate a file name. Then save to the path given by `self.cache_data_path`. - - Parameters: - verbose: Whether to print the progress - compress: Whether to compress the data - - """ - full_cache_data_path = self.get_data_cache_fullname(compress=compress) - if full_cache_data_path is None: - logger.info("No cache data path specified. 
Skipping saving the data to cache.") - return - - save_params = { - "single_task_datasets": self.single_task_datasets, - "task_train_indices": self.task_train_indices, - "task_val_indices": self.task_val_indices, - "task_test_indices": self.task_test_indices, - } - - fs.mkdir(self.cache_data_path) - with fsspec.open(full_cache_data_path, mode="wb", compression="infer") as file: - if verbose: - logger.info(f"Saving the data to cache at path:\n`{full_cache_data_path}`") - now = time.time() - torch.save(save_params, file) - elapsed = round(time.time() - now) - if verbose: - logger.info( - f"Successfully saved the data to cache in {elapsed}s at path: `{full_cache_data_path}`" - ) - - # At the moment, we need to merge the `SingleTaskDataset`'s into `MultitaskDataset`s in order to save label stats - # This is because the combined labels need to be stored together. We can investigate not doing this if this is a problem - temp_train_dataset = self._make_multitask_dataset( - stage="train", save_smiles_and_ids=False, load_from_file=False - ) - - self.get_label_statistics(self.cache_data_path, self.data_hash, temp_train_dataset, train=True) - def load_data_from_cache(self, verbose: bool = True, compress: bool = False) -> bool: """ Load the datasets from cache. First create a hash for the dataset, and verify if that - hash is available at the path given by `self.cache_data_path`. + hash is available at the path given by `self.processed_graph_data_path`. 
Parameters: verbose: Whether to print the progress @@ -2193,7 +2143,8 @@ class GraphOGBDataModule(MultitaskFromSmilesDataModule): def __init__( self, task_specific_args: Dict[str, Union[DatasetProcessingParams, Dict[str, Any]]], - cache_data_path: Optional[Union[str, os.PathLike]] = None, + processed_graph_data_path: Optional[Union[str, os.PathLike]] = None, + dataloading_from: str = "ram", featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, @@ -2220,8 +2171,9 @@ def __init__( "ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv". - "sample_size": The number of molecules to sample from the dataset. Default=None, meaning that all molecules will be considered. - cache_data_path: path where to save or reload the cached data. The path can be - remote (S3, GS, etc). + processed_graph_data_path: Path to the processed graph data. If None, the data will be + downloaded from the OGB website. + dataloading_from: Whether to load the data from RAM or disk. Default is "ram". featurization: args to apply to the SMILES to Graph featurizer. batch_size_training: batch size for training and val dataset. batch_size_inference: batch size for test dataset. 
@@ -2266,7 +2218,9 @@ def __init__( # Config for datamodule dm_args = {} dm_args["task_specific_args"] = new_task_specific_args - dm_args["cache_data_path"] = cache_data_path + dm_args["processed_graph_data_path"] = processed_graph_data_path + dm_args["dataloading_from"] = dataloading_from + dm_args["dataloader_from"] = dataloading_from dm_args["featurization"] = featurization dm_args["batch_size_training"] = batch_size_training dm_args["batch_size_inference"] = batch_size_inference @@ -2449,7 +2403,8 @@ def __init__( tdc_benchmark_names: Optional[Union[str, List[str]]] = None, tdc_train_val_seed: int = 0, # Inherited arguments from superclass - cache_data_path: Optional[Union[str, os.PathLike]] = None, + processed_graph_data_path: Optional[Union[str, Path]] = None, + dataloading_from: str = "ram", featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, @@ -2506,8 +2461,9 @@ def __init__( super().__init__( task_specific_args=task_specific_args, - cache_data_path=cache_data_path, featurization=featurization, + processed_graph_data_path=processed_graph_data_path, + dataloading_from=dataloading_from, batch_size_training=batch_size_training, batch_size_inference=batch_size_inference, batch_size_per_pack=batch_size_per_pack, @@ -2591,7 +2547,6 @@ class FakeDataModule(MultitaskFromSmilesDataModule): def __init__( self, task_specific_args: Dict[str, Dict[str, Any]], # TODO: Replace this with DatasetParams - cache_data_path: Optional[Union[str, os.PathLike]] = None, featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, @@ -2606,7 +2561,6 @@ def __init__( ): super().__init__( task_specific_args=task_specific_args, - cache_data_path=cache_data_path, featurization=featurization, batch_size_training=batch_size_training, batch_size_inference=batch_size_inference, diff --git a/graphium/data/dataset.py 
b/graphium/data/dataset.py index 180e3275f..039d1b35a 100644 --- a/graphium/data/dataset.py +++ b/graphium/data/dataset.py @@ -8,6 +8,8 @@ import os import numpy as np +from datamol import parallelized, parallelized_with_batches + import torch from torch.utils.data.dataloader import Dataset from torch_geometric.data import Data, Batch @@ -146,8 +148,8 @@ def __init__( save_smiles_and_ids: bool = False, about: str = "", data_path: Optional[Union[str, os.PathLike]] = None, - load_from_file: bool = False, - files_ready: bool = False, + dataloading_from: str = "ram", + data_is_cached: bool = False, ): r""" This class holds the information for the multitask dataset. @@ -169,27 +171,31 @@ def __init__( progress: Whether to display the progress bar save_smiles_and_ids: Whether to save the smiles and ids for the dataset. If `False`, `mol_ids` and `smiles` are set to `None` about: A description of the dataset - progress: Whether to display the progress bar - about: A description of the dataset data_path: The location of the data if saved on disk - load_from_file: Whether to load the data from disk - files_ready: Whether the files to load from were prepared ahead of time + dataloading_from: Whether to load the data from `"disk"` or `"ram"` + data_is_cached: Whether the data is already cached on `"disk"` """ super().__init__() - # self.datasets = datasets self.n_jobs = n_jobs self.backend = backend self.featurization_batch_size = featurization_batch_size self.progress = progress self.about = about + self.save_smiles_and_ids = save_smiles_and_ids self.data_path = data_path - self.load_from_file = load_from_file + self.dataloading_from = dataloading_from - if files_ready: - assert load_from_file + logger.info(f"Dataloading from {dataloading_from.upper()}") + + if data_is_cached: self._load_metadata() - self.features = None - self.labels = None + + if dataloading_from == "disk": + self.features = None + self.labels = None + elif dataloading_from == "ram": + 
logger.info("Transferring data from DISK to RAM...") + self.transfer_from_disk_to_ram() else: task = next(iter(datasets)) @@ -210,9 +216,46 @@ def __init__( if self.features is not None: self._num_nodes_list = get_num_nodes_per_graph(self.features) self._num_edges_list = get_num_edges_per_graph(self.features) - if self.load_from_file: - self.features = None - self.labels = None + + def transfer_from_disk_to_ram(self, parallel_with_batches: bool = False): + """ + Function parallelizing transfer from DISK to RAM + """ + + def transfer_mol_from_disk_to_ram(idx): + """ + Function transferring single mol from DISK to RAM + """ + data_dict = self.load_graph_from_index(idx) + mol_in_ram = { + "features": data_dict["graph_with_features"], + "labels": data_dict["labels"], + } + + return mol_in_ram + + if parallel_with_batches and self.featurization_batch_size: + data_in_ram = parallelized_with_batches( + transfer_mol_from_disk_to_ram, + range(self.dataset_length), + batch_size=self.featurization_batch_size, + n_jobs=0, + backend=self.backend, + progress=self.progress, + tqdm_kwargs={"desc": "Transfer from DISK to RAM"}, + ) + else: + data_in_ram = parallelized( + transfer_mol_from_disk_to_ram, + range(self.dataset_length), + n_jobs=0, + backend=self.backend, + progress=self.progress, + tqdm_kwargs={"desc": "Transfer from DISK to RAM"}, + ) + + self.features = [sample["features"] for sample in data_in_ram] + self.labels = [sample["labels"] for sample in data_in_ram] def save_metadata(self, directory: str): """ @@ -261,6 +304,14 @@ def _load_metadata(self): for attr, value in attrs.items(): setattr(self, attr, value) + if self.save_smiles_and_ids: + if self.smiles is None or self.mol_ids is None: + logger.warning( + f"Argument `save_smiles_and_ids` is set to {self.save_smiles_and_ids} but metadata in the cache at {self.data_path} does not contain smiles and mol_ids. 
" + f"This may be because `Datamodule.prepare_data(save_smiles_and_ids=False)` was run followed by `Datamodule.setup(save_smiles_and_ids=True)`. " + f"When loading from cached files, the `save_smiles_and_ids` argument of `Datamodule.setup()` is superseeded by the `Datamodule.prepare_data()`. " + ) + def __len__(self): r""" Returns the number of molecules @@ -377,7 +428,7 @@ def __getitem__(self, idx): A dictionary containing the data for the specified index with keys "mol_ids", "smiles", "labels", and "features" """ datum = {} - if self.load_from_file: + if self.dataloading_from == "disk": data_dict = self.load_graph_from_index(idx) datum["features"] = data_dict["graph_with_features"] datum["labels"] = data_dict["labels"] diff --git a/profiling/configs_profiling.yaml b/profiling/configs_profiling.yaml index ba72c3b64..0ff4f6c94 100644 --- a/profiling/configs_profiling.yaml +++ b/profiling/configs_profiling.yaml @@ -6,7 +6,7 @@ datamodule: module_type: "DGLFromSmilesDataModule" args: df_path: https://storage.googleapis.com/graphium-public/datasets/graphium-zinc-bench-gnn/smiles_score.csv.gz - cache_data_path: null # graphium/data/cache/ZINC_bench_gnn/smiles_score.cache + processed_graph_data_path: null label_cols: ['score'] smiles_col: SMILES diff --git a/profiling/profile_predictor.py b/profiling/profile_predictor.py index be9810d00..80ad284d4 100644 --- a/profiling/profile_predictor.py +++ b/profiling/profile_predictor.py @@ -20,7 +20,9 @@ def main(): with fsspec.open(CONFIG_PATH, "r") as f: cfg = yaml.safe_load(f) - cfg["datamodule"]["args"]["cache_data_path"] = "graphium/data/cache/profiling/predictor_data.cache" + cfg["datamodule"]["args"][ + "processed_graph_data_path" + ] = "graphium/data/cache/profiling/predictor_data.cache" # cfg["datamodule"]["args"]["df_path"] = DATA_PATH cfg["trainer"]["trainer"]["max_epochs"] = 5 cfg["trainer"]["trainer"]["min_epochs"] = 5 diff --git a/pyproject.toml b/pyproject.toml index 9e55eb5f9..20cfa9792 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ dependencies = [ [project.scripts] graphium = "graphium.cli.main:main_cli" graphium-train = "graphium.cli.train_finetune:cli" + graphium-prepare-data = "graphium.cli.prepare_data:cli" [project.urls] Website = "https://graphium.datamol.io/" diff --git a/tests/config_test_ipu_dataloader_multitask.yaml b/tests/config_test_ipu_dataloader_multitask.yaml index 55d177622..8b8fbf417 100644 --- a/tests/config_test_ipu_dataloader_multitask.yaml +++ b/tests/config_test_ipu_dataloader_multitask.yaml @@ -130,7 +130,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: -1 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/tests/data/config_micro_ZINC.yaml b/tests/data/config_micro_ZINC.yaml index e8b1a2b92..88fc4a841 100644 --- a/tests/data/config_micro_ZINC.yaml +++ b/tests/data/config_micro_ZINC.yaml @@ -6,7 +6,7 @@ datamodule: module_type: "DGLFromSmilesDataModule" args: df_path: graphium/data/micro_ZINC/micro_ZINC.csv - cache_data_path: graphium/data/cache/micro_ZINC/full.cache + processed_graph_data_path: graphium/data/cache/micro_ZINC/ label_cols: ['score'] smiles_col: SMILES diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 1cd09d036..2bc89200c 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -31,25 +31,36 @@ def test_ogb_datamodule(self): task_specific_args = {} task_specific_args["task_1"] = {"task_level": "graph", "dataset_name": dataset_name} dm_args = {} - dm_args["cache_data_path"] = None + dm_args["processed_graph_data_path"] = None dm_args["featurization"] = featurization_args dm_args["batch_size_training"] = 16 dm_args["batch_size_inference"] = 16 dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["featurization_n_jobs"] = 2 + dm_args["featurization_n_jobs"] = 0 
dm_args["featurization_progress"] = True dm_args["featurization_backend"] = "loky" dm_args["featurization_batch_size"] = 50 ds = GraphOGBDataModule(task_specific_args, **dm_args) - ds.prepare_data() + ds.prepare_data(save_smiles_and_ids=False) # Check the keys in the dataset ds.setup(save_smiles_and_ids=False) assert set(ds.train_ds[0].keys()) == {"features", "labels"} + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + # Reset the datamodule + ds._data_is_prepared = False + ds._data_is_cached = False + + ds.prepare_data(save_smiles_and_ids=True) + + # Check the keys in the dataset ds.setup(save_smiles_and_ids=True) assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} @@ -163,7 +174,6 @@ def test_caching(self): featurization_args = {} featurization_args["atom_property_list_float"] = [] # ["weight", "valence"] featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"] - # featurization_args["conformer_property_list"] = ["positions_3d"] featurization_args["edge_property_list"] = ["bond-type-onehot"] featurization_args["add_self_loop"] = False featurization_args["use_bonds_weights"] = False @@ -178,7 +188,7 @@ def test_caching(self): dm_args["batch_size_inference"] = 16 dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["featurization_n_jobs"] = 2 + dm_args["featurization_n_jobs"] = 0 dm_args["featurization_progress"] = True dm_args["featurization_backend"] = "loky" dm_args["featurization_batch_size"] = 50 @@ -189,24 +199,113 @@ def test_caching(self): # Prepare the data. 
It should create the cache there assert not exists(TEMP_CACHE_DATA_PATH) - ds = GraphOGBDataModule(task_specific_args, cache_data_path=TEMP_CACHE_DATA_PATH, **dm_args) - assert not ds.load_data_from_cache(verbose=False) - ds.prepare_data() + ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) + # assert not ds.load_data_from_cache(verbose=False) + ds.prepare_data(save_smiles_and_ids=False) # Check the keys in the dataset ds.setup(save_smiles_and_ids=False) assert set(ds.train_ds[0].keys()) == {"features", "labels"} - ds.setup(save_smiles_and_ids=True) - assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} + # ds_batch = next(iter(ds.train_dataloader())) + train_loader = ds.get_dataloader(ds.train_ds, shuffle=False, stage="train") + batch = next(iter(train_loader)) + + # Test loading cached data + assert exists(TEMP_CACHE_DATA_PATH) + + cached_ds_from_ram = GraphOGBDataModule( + task_specific_args, + processed_graph_data_path=TEMP_CACHE_DATA_PATH, + dataloading_from="ram", + **dm_args, + ) + cached_ds_from_ram.prepare_data() + cached_ds_from_ram.setup() + cached_train_loader_from_ram = cached_ds_from_ram.get_dataloader( + cached_ds_from_ram.train_ds, shuffle=False, stage="train" + ) + batch_from_ram = next(iter(cached_train_loader_from_ram)) + + cached_ds_from_disk = GraphOGBDataModule( + task_specific_args, + processed_graph_data_path=TEMP_CACHE_DATA_PATH, + dataloading_from="disk", + **dm_args, + ) + cached_ds_from_disk.prepare_data() + cached_ds_from_disk.setup() + cached_train_loader_from_disk = cached_ds_from_disk.get_dataloader( + cached_ds_from_disk.train_ds, shuffle=False, stage="train" + ) + batch_from_disk = next(iter(cached_train_loader_from_disk)) + + # Features are the same + np.testing.assert_array_almost_equal( + batch["features"].edge_index, batch_from_ram["features"].edge_index + ) + np.testing.assert_array_almost_equal( + batch["features"].edge_index, 
batch_from_disk["features"].edge_index + ) + + assert batch["features"].num_nodes == batch_from_ram["features"].num_nodes + assert batch["features"].num_nodes == batch_from_disk["features"].num_nodes + + np.testing.assert_array_almost_equal( + batch["features"].edge_weight, batch_from_ram["features"].edge_weight + ) + np.testing.assert_array_almost_equal( + batch["features"].edge_weight, batch_from_disk["features"].edge_weight + ) + + np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_ram["features"].feat) + np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_disk["features"].feat) + + np.testing.assert_array_almost_equal( + batch["features"].edge_feat, batch_from_ram["features"].edge_feat + ) + np.testing.assert_array_almost_equal( + batch["features"].edge_feat, batch_from_disk["features"].edge_feat + ) + + np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_ram["features"].batch) + np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_disk["features"].batch) + + np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_ram["features"].ptr) + np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_disk["features"].ptr) + + # Labels are the same + np.testing.assert_array_almost_equal( + batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1 + ) + np.testing.assert_array_almost_equal( + batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1 + ) + + np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_ram["labels"].x) + np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_disk["labels"].x) - # Make sure that the cache is created - full_cache_path = ds.get_data_cache_fullname(compress=False) - assert exists(full_cache_path) - assert get_size(full_cache_path) > 10000 + np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_ram["labels"].edge_index) + 
np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_disk["labels"].edge_index) - # Check that the data is loaded correctly from cache - assert ds.load_data_from_cache(verbose=False) + np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_ram["labels"].batch) + np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_disk["labels"].batch) + + np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_ram["labels"].ptr) + np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_disk["labels"].ptr) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + # Reset the datamodule + ds._data_is_prepared = False + ds._data_is_cached = False + + ds.prepare_data(save_smiles_and_ids=True) + + ds.setup(save_smiles_and_ids=True) + assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} # test module assert ds.num_edge_feats == 5 @@ -219,6 +318,10 @@ def test_caching(self): assert len(batch["labels"]["graph_task_1"]) == 16 assert len(batch["mol_ids"]) == 16 + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + def test_datamodule_with_none_molecules(self): # Setup the featurization featurization_args = {} @@ -335,7 +438,7 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args) + ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) ds.prepare_data() ds.setup() @@ -348,7 +451,7 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args) + ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) ds.prepare_data() 
ds.setup() @@ -361,7 +464,7 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args) + ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) ds.prepare_data() ds.setup() @@ -374,7 +477,7 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args) + ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) ds.prepare_data() ds.setup() diff --git a/tests/test_multitask_datamodule.py b/tests/test_multitask_datamodule.py index 796335964..d74fc77ec 100644 --- a/tests/test_multitask_datamodule.py +++ b/tests/test_multitask_datamodule.py @@ -100,7 +100,7 @@ def test_multitask_fromsmiles_dm( dm_args["featurization_backend"] = "loky" dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["cache_data_path"] = None + dm_args["processed_graph_data_path"] = None dm_args["batch_size_training"] = 16 dm_args["batch_size_inference"] = 16