From cc27c8ffc012c925600a6cabb61af3165e8bae34 Mon Sep 17 00:00:00 2001
From: DomInvivo
Date: Thu, 10 Aug 2023 11:57:36 -0400
Subject: [PATCH 01/14] First pass to change the graph caching logic

---
 graphium/data/datamodule.py        | 142 +++++++++++------------------
 graphium/data/dataset.py           |  15 +--
 profiling/profile_predictor.py     |   2 +-
 tests/test_datamodule.py           |   4 +-
 tests/test_multitask_datamodule.py |   2 +-
 5 files changed, 67 insertions(+), 98 deletions(-)

diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py
index b85fd664c..d81249bbf 100644
--- a/graphium/data/datamodule.py
+++ b/graphium/data/datamodule.py
@@ -770,8 +770,8 @@ class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier):
     def __init__(
         self,
         task_specific_args: Union[DatasetProcessingParams, Dict[str, Any]],
-        cache_data_path: Optional[Union[str, os.PathLike]] = None,
         processed_graph_data_path: Optional[Union[str, os.PathLike]] = None,
+        dataloading_from: str = "ram",
         featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
         batch_size_training: int = 16,
         batch_size_inference: int = 16,
@@ -835,8 +835,11 @@ def __init__(
             task_splits_path: (value) A path a CSV file containing indices for the splits. The file must contains
                 3 columns "train", "val" and "test". It takes precedence over `split_val` and `split_test`.
-            cache_data_path: path where to save or reload the cached data. The path can be
-                remote (S3, GS, etc).
+            processed_graph_data_path: path where to save or reload the cached data. Can be used
+                to avoid recomputing the featurization, or for dataloading from disk with the option `dataloading_from="disk"`.
+            dataloading_from: Whether to load the data from RAM or from disk. If set to "disk", the data
+                must have been previously cached with `processed_graph_data_path` set. If set to "ram", the data
+                will be loaded in RAM and the `processed_graph_data_path` will be ignored.
             featurization: args to apply to the SMILES to Graph featurizer.
             batch_size_training: batch size for training and val dataset.
             batch_size_inference: batch size for test dataset.
@@ -909,10 +912,7 @@ def __init__(
         self.val_ds = None
         self.test_ds = None
 
-        self.cache_data_path = cache_data_path
-
-        self.processed_graph_data_path = processed_graph_data_path
-
-        self.load_from_file = processed_graph_data_path is not None
+        self._parse_caching_args(processed_graph_data_path, dataloading_from)
 
         self.task_norms = {}
 
@@ -932,6 +932,27 @@ def __init__(
         )
         self.data_hash = self.get_data_hash()
 
+    def _parse_caching_args(self, processed_graph_data_path, dataloading_from):
+        """
+        Parse the caching arguments, and raise errors if the arguments are invalid.
+        """
+
+        # Whether to load the data from RAM or from disk
+        dataloading_from = dataloading_from.lower()
+        if dataloading_from not in ["disk", "ram"]:
+            raise ValueError(
+                f"`dataloading_from` should be either 'disk' or 'ram'. Provided: `{dataloading_from}`"
+            )
+
+        # If loading from disk, the path to the cached data must be provided
+        if dataloading_from == "disk" and processed_graph_data_path is None:
+            raise ValueError(
+                "When `dataloading_from` is 'disk', `processed_graph_data_path` must be provided."
+            )
+
+        self.processed_graph_data_path = processed_graph_data_path
+        self.dataloading_from = dataloading_from
+
     def _get_task_key(self, task_level: str, task: str):
         task_prefix = f"{task_level}_"
         if not task.startswith(task_prefix):
@@ -976,23 +997,12 @@ def has_atoms_after_h_removal(smiles):
             logger.info("Data is already prepared. 
Skipping the preparation") return - if self.load_from_file: + if self.dataloading_from == "disk": if self._ready_to_load_all_from_file(): self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None) self._data_is_prepared = True return - else: - # If a path for data caching is provided, try to load from the path. - # If successful, skip the data preparation. - # For next task: load the single graph files for train, val and test data - cache_data_exists = self.load_data_from_cache() - # need to check if cache exist properly - if cache_data_exists: - self.get_label_statistics(self.cache_data_path, self.data_hash, dataset=None) - self._data_is_prepared = True - return - """Load all single-task dataframes.""" task_df = {} for task, args in self.task_dataset_processing_params.items(): @@ -1160,13 +1170,9 @@ def has_atoms_after_h_removal(smiles): self.single_task_datasets, self.task_train_indices, self.task_val_indices, self.task_test_indices ) - if self.load_from_file: + if self.processed_graph_data_path is not None: self._save_data_to_files() - # When a cache path is provided but no cache is found, save to cache - elif (self.cache_data_path is not None) and (not cache_data_exists): - self.save_data_to_cache() - self._data_is_prepared = True def setup( @@ -1185,7 +1191,7 @@ def setup( labels_size = {} labels_dtype = {} if stage == "fit" or stage is None: - if self.load_from_file: + if self.processed_graph_data_path is not None: processed_train_data_path = self._path_to_load_from_file("train") assert self._data_ready_at_path( processed_train_data_path @@ -1210,7 +1216,7 @@ def setup( labels_dtype.update(self.val_ds.labels_dtype) if stage == "test" or stage is None: - if self.load_from_file: + if self.dataloading_from == "disk": processed_test_data_path = self._path_to_load_from_file("test") assert self._data_ready_at_path( processed_test_data_path @@ -1237,7 +1243,7 @@ def _make_multitask_dataset( self, stage: Literal["train", "val", "test"], save_smiles_and_ids: bool, - load_from_file: Optional[bool] = None, + processed_graph_data_path: Optional[str] = None, ) -> Datasets.MultitaskDataset: """ Create a MultitaskDataset for the given stage using single task datasets @@ -1246,8 +1252,7 @@ def _make_multitask_dataset( Parameters: stage: Stage to create multitask dataset for save_smiles_and_ids: Whether to save SMILES strings and unique IDs - data_path: path to load from if loading from file - load_from_file: whether to load from file. 
If `None`, defers to `self.load_from_file` + processed_graph_data_path: path to save and load processed graph data from """ allowed_stages = ["train", "val", "test"] @@ -1265,12 +1270,12 @@ def _make_multitask_dataset( else: raise ValueError(f"Unknown stage {stage}") - if load_from_file is None: - load_from_file = self.load_from_file + if processed_graph_data_path is None: + processed_graph_data_path = self.processed_graph_data_path # assert singletask_datasets is not None, "Single task datasets must exist to make multitask dataset" if singletask_datasets is None: - assert load_from_file + assert processed_graph_data_path is not None assert self._data_ready_at_path( self._path_to_load_from_file(stage) ), "Trying to create multitask dataset without single-task datasets but data not ready" @@ -1286,8 +1291,8 @@ def _make_multitask_dataset( progress=self.featurization_progress, about=about, save_smiles_and_ids=save_smiles_and_ids, - data_path=self._path_to_load_from_file(stage) if load_from_file else None, - load_from_file=load_from_file, + data_path=self._path_to_load_from_file(stage) if processed_graph_data_path else None, + processed_graph_data_path=processed_graph_data_path, files_ready=files_ready, ) # type: ignore @@ -1296,7 +1301,7 @@ def _make_multitask_dataset( self.get_label_statistics( self.processed_graph_data_path, self.data_hash, multitask_dataset, train=True ) - if not load_from_file: + if self.dataloading_from == "ram": self.normalize_label(multitask_dataset, stage) return multitask_dataset @@ -2004,60 +2009,18 @@ def get_data_cache_fullname(self, compress: bool = False) -> str: Returns: full path to the data cache file """ - if self.cache_data_path is None: + if self.processed_graph_data_path is None: return ext = ".datacache" if compress: ext += ".gz" - data_cache_fullname = fs.join(self.cache_data_path, self.data_hash + ext) + data_cache_fullname = fs.join(self.processed_graph_data_path, self.data_hash + ext) return data_cache_fullname - def save_data_to_cache(self, verbose: bool = True, compress: bool = False) -> None: - """ - Save the datasets from cache. First create a hash for the dataset, use it to - generate a file name. Then save to the path given by `self.cache_data_path`. - - Parameters: - verbose: Whether to print the progress - compress: Whether to compress the data - - """ - full_cache_data_path = self.get_data_cache_fullname(compress=compress) - if full_cache_data_path is None: - logger.info("No cache data path specified. Skipping saving the data to cache.") - return - - save_params = { - "single_task_datasets": self.single_task_datasets, - "task_train_indices": self.task_train_indices, - "task_val_indices": self.task_val_indices, - "task_test_indices": self.task_test_indices, - } - - fs.mkdir(self.cache_data_path) - with fsspec.open(full_cache_data_path, mode="wb", compression="infer") as file: - if verbose: - logger.info(f"Saving the data to cache at path:\n`{full_cache_data_path}`") - now = time.time() - torch.save(save_params, file) - elapsed = round(time.time() - now) - if verbose: - logger.info( - f"Successfully saved the data to cache in {elapsed}s at path: `{full_cache_data_path}`" - ) - - # At the moment, we need to merge the `SingleTaskDataset`'s into `MultitaskDataset`s in order to save label stats - # This is because the combined labels need to be stored together. 
We can investigate not doing this if this is a problem - temp_train_dataset = self._make_multitask_dataset( - stage="train", save_smiles_and_ids=False, load_from_file=False - ) - - self.get_label_statistics(self.cache_data_path, self.data_hash, temp_train_dataset, train=True) - def load_data_from_cache(self, verbose: bool = True, compress: bool = False) -> bool: """ Load the datasets from cache. First create a hash for the dataset, and verify if that - hash is available at the path given by `self.cache_data_path`. + hash is available at the path given by `self.processed_graph_data_path`. Parameters: verbose: Whether to print the progress @@ -2193,7 +2156,8 @@ class GraphOGBDataModule(MultitaskFromSmilesDataModule): def __init__( self, task_specific_args: Dict[str, Union[DatasetProcessingParams, Dict[str, Any]]], - cache_data_path: Optional[Union[str, os.PathLike]] = None, + processed_graph_data_path: Optional[Union[str, os.PathLike]] = None, + dataloading_from: str = "ram", featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, @@ -2220,8 +2184,9 @@ def __init__( "ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv". - "sample_size": The number of molecules to sample from the dataset. Default=None, meaning that all molecules will be considered. - cache_data_path: path where to save or reload the cached data. The path can be - remote (S3, GS, etc). + processed_graph_data_path: Path to the processed graph data. If None, the data will be + downloaded from the OGB website. + dataloading_from: Whether to load the data from RAM or disk. Default is "ram". featurization: args to apply to the SMILES to Graph featurizer. batch_size_training: batch size for training and val dataset. batch_size_inference: batch size for test dataset. 
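The hunks above and below retire `cache_data_path` in favor of a single `processed_graph_data_path` combined with the `dataloading_from` switch. A minimal sketch of the intended usage, assuming the class is importable from `graphium.data`; the CSV path, column names, and cache directory below are placeholders, not files that ship with the repo:

```python
from graphium.data import MultitaskFromSmilesDataModule

# Hypothetical single-task setup; the dataset file is a placeholder.
task_specific_args = {
    "graph_score": {
        "task_level": "graph",
        "df_path": "data/my_molecules.csv",
        "smiles_col": "SMILES",
        "label_cols": ["score"],
    }
}

# First run: featurize once, cache the processed graphs on disk,
# and serve batches from RAM (the default).
dm = MultitaskFromSmilesDataModule(
    task_specific_args=task_specific_args,
    processed_graph_data_path="datacache/my_molecules/",
    dataloading_from="ram",
    batch_size_training=16,
    batch_size_inference=16,
)
dm.prepare_data()  # skipped on later runs once the cache exists

# Memory-constrained run: stream graphs from the same cache instead.
# "disk" requires `processed_graph_data_path`, otherwise a ValueError is raised.
dm_disk = MultitaskFromSmilesDataModule(
    task_specific_args=task_specific_args,
    processed_graph_data_path="datacache/my_molecules/",
    dataloading_from="disk",
    batch_size_training=16,
    batch_size_inference=16,
)
```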
@@ -2266,7 +2231,8 @@ def __init__(
         # Config for datamodule
         dm_args = {}
         dm_args["task_specific_args"] = new_task_specific_args
-        dm_args["cache_data_path"] = cache_data_path
+        dm_args["processed_graph_data_path"] = processed_graph_data_path
+        dm_args["dataloading_from"] = dataloading_from
         dm_args["featurization"] = featurization
         dm_args["batch_size_training"] = batch_size_training
         dm_args["batch_size_inference"] = batch_size_inference
@@ -2449,7 +2415,8 @@ def __init__(
         tdc_benchmark_names: Optional[Union[str, List[str]]] = None,
         tdc_train_val_seed: int = 0,
         # Inherited arguments from superclass
-        cache_data_path: Optional[Union[str, os.PathLike]] = None,
+        processed_graph_data_path: Optional[Union[str, Path]] = None,
+        dataloading_from: str = "ram",
         featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
         batch_size_training: int = 16,
         batch_size_inference: int = 16,
@@ -2506,8 +2473,9 @@ def __init__(
 
         super().__init__(
             task_specific_args=task_specific_args,
-            cache_data_path=cache_data_path,
             featurization=featurization,
+            processed_graph_data_path=processed_graph_data_path,
+            dataloading_from=dataloading_from,
             batch_size_training=batch_size_training,
             batch_size_inference=batch_size_inference,
             batch_size_per_pack=batch_size_per_pack,
@@ -2591,7 +2559,6 @@ class FakeDataModule(MultitaskFromSmilesDataModule):
     def __init__(
         self,
         task_specific_args: Dict[str, Dict[str, Any]],  # TODO: Replace this with DatasetParams
-        cache_data_path: Optional[Union[str, os.PathLike]] = None,
         featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
         batch_size_training: int = 16,
         batch_size_inference: int = 16,
@@ -2606,7 +2573,6 @@ def __init__(
     ):
         super().__init__(
             task_specific_args=task_specific_args,
-            cache_data_path=cache_data_path,
             featurization=featurization,
             batch_size_training=batch_size_training,
             batch_size_inference=batch_size_inference,
diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py
index 180e3275f..c89758d29 100644
--- a/graphium/data/dataset.py
+++ b/graphium/data/dataset.py
@@ -146,7 +146,7 @@ def __init__(
         save_smiles_and_ids: bool = False,
         about: str = "",
         data_path: Optional[Union[str, os.PathLike]] = None,
-        load_from_file: bool = False,
+        dataloading_from: str = "ram",
         files_ready: bool = False,
     ):
         r"""
@@ -172,7 +172,7 @@ def __init__(
             progress: Whether to display the progress bar
             about: A description of the dataset
             data_path: The location of the data if saved on disk
-            load_from_file: Whether to load the data from disk
+            dataloading_from: Whether to load the data from `"disk"` or `"ram"`
             files_ready: Whether the files to load from were prepared ahead of time
         """
         super().__init__()
@@ -183,10 +183,13 @@ def __init__(
         self.progress = progress
         self.about = about
         self.data_path = data_path
-        self.load_from_file = load_from_file
+        self.dataloading_from = dataloading_from
 
         if files_ready:
-            assert load_from_file
+            if dataloading_from != "disk":
+                raise ValueError(
+                    "Files are ready to be loaded from disk, but `dataloading_from` is not set to `disk`"
+                )
             self._load_metadata()
             self.features = None
             self.labels = None
@@ -210,7 +213,7 @@ def __init__(
         if self.features is not None:
             self._num_nodes_list = get_num_nodes_per_graph(self.features)
             self._num_edges_list = get_num_edges_per_graph(self.features)
-        if self.load_from_file:
+        if self.dataloading_from == "disk":
             self.features = None
             self.labels = None
 
@@ -377,7 +380,7 @@ def __getitem__(self, idx):
             A dictionary containing the data for the specified index with keys "mol_ids", "smiles", 
"labels", and "features" """ datum = {} - if self.load_from_file: + if self.dataloading_from == "disk": data_dict = self.load_graph_from_index(idx) datum["features"] = data_dict["graph_with_features"] datum["labels"] = data_dict["labels"] diff --git a/profiling/profile_predictor.py b/profiling/profile_predictor.py index be9810d00..16df450c1 100644 --- a/profiling/profile_predictor.py +++ b/profiling/profile_predictor.py @@ -20,7 +20,7 @@ def main(): with fsspec.open(CONFIG_PATH, "r") as f: cfg = yaml.safe_load(f) - cfg["datamodule"]["args"]["cache_data_path"] = "graphium/data/cache/profiling/predictor_data.cache" + cfg["datamodule"]["args"]["processed_graph_data_path"] = "graphium/data/cache/profiling/predictor_data.cache" # cfg["datamodule"]["args"]["df_path"] = DATA_PATH cfg["trainer"]["trainer"]["max_epochs"] = 5 cfg["trainer"]["trainer"]["min_epochs"] = 5 diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 1cd09d036..45cd30106 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -31,7 +31,7 @@ def test_ogb_datamodule(self): task_specific_args = {} task_specific_args["task_1"] = {"task_level": "graph", "dataset_name": dataset_name} dm_args = {} - dm_args["cache_data_path"] = None + dm_args["processed_graph_data_path"] = None dm_args["featurization"] = featurization_args dm_args["batch_size_training"] = 16 dm_args["batch_size_inference"] = 16 @@ -189,7 +189,7 @@ def test_caching(self): # Prepare the data. It should create the cache there assert not exists(TEMP_CACHE_DATA_PATH) - ds = GraphOGBDataModule(task_specific_args, cache_data_path=TEMP_CACHE_DATA_PATH, **dm_args) + ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) assert not ds.load_data_from_cache(verbose=False) ds.prepare_data() diff --git a/tests/test_multitask_datamodule.py b/tests/test_multitask_datamodule.py index 796335964..d74fc77ec 100644 --- a/tests/test_multitask_datamodule.py +++ b/tests/test_multitask_datamodule.py @@ -100,7 +100,7 @@ def test_multitask_fromsmiles_dm( dm_args["featurization_backend"] = "loky" dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["cache_data_path"] = None + dm_args["processed_graph_data_path"] = None dm_args["batch_size_training"] = 16 dm_args["batch_size_inference"] = 16 From bdf14d1802bd0ff7dfa9e6a80f8e2469dd5c6c8a Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 10 Aug 2023 12:08:55 -0400 Subject: [PATCH 02/14] changing the yaml configs with the new caching logic --- expts/configs/config_gps_10M_pcqm4m.yaml | 1 - expts/configs/config_gps_10M_pcqm4m_mod.yaml | 1 - expts/configs/config_mpnn_10M_b3lyp.yaml | 2 +- expts/configs/config_mpnn_pcqm4m.yaml | 3 +-- expts/neurips2023_configs/base_config/large.yaml | 1 - expts/neurips2023_configs/base_config/small.yaml | 1 - .../baseline/config_small_gcn_baseline.yaml | 1 - .../config_classifigression_l1000.yaml | 1 - expts/neurips2023_configs/config_luis_jama.yaml | 1 - expts/neurips2023_configs/debug/config_debug.yaml | 1 - .../debug/config_large_gcn_debug.yaml | 3 +-- .../debug/config_small_gcn_debug.yaml | 1 - .../single_task_gcn/config_large_gcn_mcf7.yaml | 1 - .../single_task_gcn/config_large_gcn_pcba.yaml | 1 - .../single_task_gcn/config_large_gcn_vcap.yaml | 1 - .../single_task_gin/config_large_gin_g25.yaml | 1 - .../single_task_gin/config_large_gin_mcf7.yaml | 1 - .../single_task_gin/config_large_gin_n4.yaml | 1 - .../single_task_gin/config_large_gin_pcba.yaml | 1 - .../single_task_gin/config_large_gin_pcq.yaml | 1 - 
.../single_task_gin/config_large_gin_vcap.yaml | 1 - .../single_task_gine/config_large_gine_g25.yaml | 1 - .../single_task_gine/config_large_gine_mcf7.yaml | 1 - .../single_task_gine/config_large_gine_n4.yaml | 1 - .../single_task_gine/config_large_gine_pcba.yaml | 1 - .../single_task_gine/config_large_gine_pcq.yaml | 1 - .../single_task_gine/config_large_gine_vcap.yaml | 1 - .../fake_and_missing_multilevel_multitask_pyg.yaml | 9 ++++----- graphium/config/fake_multilevel_multitask_pyg.yaml | 9 ++++----- graphium/config/zinc_default_multitask_pyg.yaml | 1 - profiling/configs_profiling.yaml | 2 +- tests/config_test_ipu_dataloader_multitask.yaml | 1 - tests/data/config_micro_ZINC.yaml | 2 +- 33 files changed, 13 insertions(+), 43 deletions(-) diff --git a/expts/configs/config_gps_10M_pcqm4m.yaml b/expts/configs/config_gps_10M_pcqm4m.yaml index 0b0dff7dc..10faa3b1e 100644 --- a/expts/configs/config_gps_10M_pcqm4m.yaml +++ b/expts/configs/config_gps_10M_pcqm4m.yaml @@ -112,7 +112,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/configs/config_gps_10M_pcqm4m_mod.yaml b/expts/configs/config_gps_10M_pcqm4m_mod.yaml index 1c9f6da31..e2cdb44c2 100644 --- a/expts/configs/config_gps_10M_pcqm4m_mod.yaml +++ b/expts/configs/config_gps_10M_pcqm4m_mod.yaml @@ -81,7 +81,6 @@ datamodule: # Data handling-related batch_size_training: 64 batch_size_inference: 16 - # cache_data_path: . num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/configs/config_mpnn_10M_b3lyp.yaml b/expts/configs/config_mpnn_10M_b3lyp.yaml index dca4cd540..c385d7689 100644 --- a/expts/configs/config_mpnn_10M_b3lyp.yaml +++ b/expts/configs/config_mpnn_10M_b3lyp.yaml @@ -93,6 +93,7 @@ datamodule: featurization_progress: True featurization_backend: "loky" processed_graph_data_path: "../datacache/b3lyp/" + dataloading_from: ram featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -123,7 +124,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/configs/config_mpnn_pcqm4m.yaml b/expts/configs/config_mpnn_pcqm4m.yaml index 4e34f89ea..9735f9555 100644 --- a/expts/configs/config_mpnn_pcqm4m.yaml +++ b/expts/configs/config_mpnn_pcqm4m.yaml @@ -30,8 +30,8 @@ datamodule: featurization_n_jobs: 20 featurization_progress: True featurization_backend: "loky" - cache_data_path: "./datacache" processed_graph_data_path: "graphium/data/PCQM4Mv2/" + dataloading_from: ram featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -58,7 +58,6 @@ datamodule: # Data handling-related batch_size_training: 64 batch_size_inference: 16 - # cache_data_path: . num_workers: 40 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. 
# Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/base_config/large.yaml b/expts/neurips2023_configs/base_config/large.yaml index db2b5dbb6..5ba023b3e 100644 --- a/expts/neurips2023_configs/base_config/large.yaml +++ b/expts/neurips2023_configs/base_config/large.yaml @@ -168,7 +168,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 32 # -1 to use all persistent_workers: True # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/base_config/small.yaml b/expts/neurips2023_configs/base_config/small.yaml index 2e63477a1..fd7ce3fbe 100644 --- a/expts/neurips2023_configs/base_config/small.yaml +++ b/expts/neurips2023_configs/base_config/small.yaml @@ -132,7 +132,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml index 401dcabd6..7b2d2cbdf 100644 --- a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml @@ -131,7 +131,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/config_classifigression_l1000.yaml b/expts/neurips2023_configs/config_classifigression_l1000.yaml index 37d83736f..48f06d9d1 100644 --- a/expts/neurips2023_configs/config_classifigression_l1000.yaml +++ b/expts/neurips2023_configs/config_classifigression_l1000.yaml @@ -111,7 +111,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 5 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/config_luis_jama.yaml b/expts/neurips2023_configs/config_luis_jama.yaml index 46ec4c4c0..5135c5cae 100644 --- a/expts/neurips2023_configs/config_luis_jama.yaml +++ b/expts/neurips2023_configs/config_luis_jama.yaml @@ -119,7 +119,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 4 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/debug/config_debug.yaml b/expts/neurips2023_configs/debug/config_debug.yaml index a323427e5..3d31e5e8c 100644 --- a/expts/neurips2023_configs/debug/config_debug.yaml +++ b/expts/neurips2023_configs/debug/config_debug.yaml @@ -105,7 +105,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
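As these configs drop the stale `# cache_data_path: .` comments, the two surviving knobs are validated at datamodule construction by `_parse_caching_args` from the first commit. A hedged, pytest-style sketch of that contract; `make_datamodule` is a hypothetical factory that fills in the remaining required arguments:

```python
import pytest

def test_parse_caching_args(make_datamodule):
    # "disk" without a cache path is rejected at construction time
    with pytest.raises(ValueError):
        make_datamodule(processed_graph_data_path=None, dataloading_from="disk")

    # anything other than "disk" or "ram" is rejected as well
    with pytest.raises(ValueError):
        make_datamodule(dataloading_from="gpu")

    # the flag is lower-cased, so "RAM" and "ram" are equivalent
    dm = make_datamodule(dataloading_from="RAM")
    assert dm.dataloading_from == "ram"
```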
diff --git a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml index 5fe6d8741..ec05bf6eb 100644 --- a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml @@ -166,7 +166,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. @@ -327,7 +326,7 @@ predictor: l1000_mcf7: [] pcba_1328: [] pcqm4m_g25: [] - pcqm4m_n4: [] + pcqm4m_n4: [] loss_fun: l1000_vcap: name: hybrid_ce_ipu diff --git a/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml index 717ae0675..26b50756f 100644 --- a/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml @@ -119,7 +119,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml index f73d4b08c..e05d1be8d 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml index 3985f26a7..cf924850e 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml index dad3893a9..f1c9bcfd4 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml index 20c6aaa37..01988e527 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml @@ -103,7 +103,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml index 12974d9e4..fdeb4b399 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml index 72320f137..5920a80f6 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml @@ -104,7 +104,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml index 1d9601ee1..de2f7fbc4 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml index 85fce8e13..ca820e86b 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml @@ -118,7 +118,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml index c52a041f1..c21b765b3 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . 
num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml index 4ab892b00..b88314797 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml @@ -103,7 +103,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml index 8605121f1..b96fc8daf 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml index ee89b6012..e98ae03da 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml @@ -104,7 +104,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml index 42ac474e9..427f7ca0f 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml index 84bd5c66b..07fc6d009 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml @@ -118,7 +118,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml index 09d23fb92..b63263b3d 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml @@ -100,7 +100,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. diff --git a/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml b/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml index 48d55a501..044a0129c 100644 --- a/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml +++ b/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml @@ -58,7 +58,6 @@ datamodule: # Data handling-related batch_size_training: 16 batch_size_inference: 16 - # cache_data_path: null architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml` model_type: FullGraphMultiTaskNetwork @@ -111,7 +110,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none graph: pooling: [sum, max] out_dim: 1 @@ -122,7 +121,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none edge: out_dim: 16 hidden_dims: 32 @@ -132,7 +131,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none nodepair: out_dim: 16 hidden_dims: 32 @@ -142,7 +141,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none task_heads: # Set as null to avoid task heads. 
Recall that the arguments for the TaskHeads is a List of TaskHeadParams task_1: diff --git a/graphium/config/fake_multilevel_multitask_pyg.yaml b/graphium/config/fake_multilevel_multitask_pyg.yaml index 3ca5085f9..918807cb4 100644 --- a/graphium/config/fake_multilevel_multitask_pyg.yaml +++ b/graphium/config/fake_multilevel_multitask_pyg.yaml @@ -58,7 +58,6 @@ datamodule: # Data handling-related batch_size_training: 16 batch_size_inference: 16 - # cache_data_path: null architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml` model_type: FullGraphMultiTaskNetwork @@ -111,7 +110,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none graph: pooling: [sum, max] out_dim: 1 @@ -122,7 +121,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none edge: out_dim: 16 hidden_dims: 32 @@ -132,7 +131,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none nodepair: out_dim: 16 hidden_dims: 32 @@ -142,7 +141,7 @@ architecture: # The parameters for the full graph network are taken from `co dropout: *dropout normalization: *normalization last_normalization: "none" - residual_type: none + residual_type: none task_heads: # Set as null to avoid task heads. Recall that the arguments for the TaskHeads is a List of TaskHeadParams task_1: diff --git a/graphium/config/zinc_default_multitask_pyg.yaml b/graphium/config/zinc_default_multitask_pyg.yaml index 192d2c4ef..07ae4bf9b 100644 --- a/graphium/config/zinc_default_multitask_pyg.yaml +++ b/graphium/config/zinc_default_multitask_pyg.yaml @@ -58,7 +58,6 @@ datamodule: # Data handling-related batch_size_training: 16 batch_size_inference: 16 - # cache_data_path: null architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml` model_type: FullGraphMultiTaskNetwork diff --git a/profiling/configs_profiling.yaml b/profiling/configs_profiling.yaml index ba72c3b64..0ff4f6c94 100644 --- a/profiling/configs_profiling.yaml +++ b/profiling/configs_profiling.yaml @@ -6,7 +6,7 @@ datamodule: module_type: "DGLFromSmilesDataModule" args: df_path: https://storage.googleapis.com/graphium-public/datasets/graphium-zinc-bench-gnn/smiles_score.csv.gz - cache_data_path: null # graphium/data/cache/ZINC_bench_gnn/smiles_score.cache + processed_graph_data_path: null label_cols: ['score'] smiles_col: SMILES diff --git a/tests/config_test_ipu_dataloader_multitask.yaml b/tests/config_test_ipu_dataloader_multitask.yaml index 55d177622..8b8fbf417 100644 --- a/tests/config_test_ipu_dataloader_multitask.yaml +++ b/tests/config_test_ipu_dataloader_multitask.yaml @@ -130,7 +130,6 @@ datamodule: pos_type: rw_return_probs ksteps: 16 - # cache_data_path: . num_workers: -1 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. 
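The same key rename applies when configs are edited programmatically, as `profiling/profile_predictor.py` was above. A hedged sketch of migrating an older config dict to the new caching keys; the config path is a placeholder, and `dataloading_from` is assumed to default to "ram":

```python
import fsspec
import yaml

CONFIG_PATH = "expts/configs/config_mpnn_pcqm4m.yaml"  # any datamodule config

with fsspec.open(CONFIG_PATH, "r") as f:
    cfg = yaml.safe_load(f)

dm_args = cfg["datamodule"]["args"]
# `cache_data_path` no longer exists; carry its value over if present.
old_cache = dm_args.pop("cache_data_path", None)
if old_cache is not None:
    dm_args.setdefault("processed_graph_data_path", old_cache)
dm_args.setdefault("dataloading_from", "ram")
```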
diff --git a/tests/data/config_micro_ZINC.yaml b/tests/data/config_micro_ZINC.yaml index e8b1a2b92..88fc4a841 100644 --- a/tests/data/config_micro_ZINC.yaml +++ b/tests/data/config_micro_ZINC.yaml @@ -6,7 +6,7 @@ datamodule: module_type: "DGLFromSmilesDataModule" args: df_path: graphium/data/micro_ZINC/micro_ZINC.csv - cache_data_path: graphium/data/cache/micro_ZINC/full.cache + processed_graph_data_path: graphium/data/cache/micro_ZINC/ label_cols: ['score'] smiles_col: SMILES From 3ebeffc75da62b089261713a05dd6d6bc088cbba Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 10 Aug 2023 12:25:06 -0400 Subject: [PATCH 03/14] applied black linting --- profiling/profile_predictor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/profiling/profile_predictor.py b/profiling/profile_predictor.py index 16df450c1..80ad284d4 100644 --- a/profiling/profile_predictor.py +++ b/profiling/profile_predictor.py @@ -20,7 +20,9 @@ def main(): with fsspec.open(CONFIG_PATH, "r") as f: cfg = yaml.safe_load(f) - cfg["datamodule"]["args"]["processed_graph_data_path"] = "graphium/data/cache/profiling/predictor_data.cache" + cfg["datamodule"]["args"][ + "processed_graph_data_path" + ] = "graphium/data/cache/profiling/predictor_data.cache" # cfg["datamodule"]["args"]["df_path"] = DATA_PATH cfg["trainer"]["trainer"]["max_epochs"] = 5 cfg["trainer"]["trainer"]["min_epochs"] = 5 From 4c08b5601ff3466de6cf6459810f9614a24605e8 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 10 Aug 2023 12:25:06 -0400 Subject: [PATCH 04/14] minor fix --- graphium/data/datamodule.py | 2 +- profiling/profile_predictor.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index d81249bbf..c898e09c8 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1191,7 +1191,7 @@ def setup( labels_size = {} labels_dtype = {} if stage == "fit" or stage is None: - if self.processed_graph_data_path is not None: + if self.dataloading_from == "disk": processed_train_data_path = self._path_to_load_from_file("train") assert self._data_ready_at_path( processed_train_data_path diff --git a/profiling/profile_predictor.py b/profiling/profile_predictor.py index 16df450c1..80ad284d4 100644 --- a/profiling/profile_predictor.py +++ b/profiling/profile_predictor.py @@ -20,7 +20,9 @@ def main(): with fsspec.open(CONFIG_PATH, "r") as f: cfg = yaml.safe_load(f) - cfg["datamodule"]["args"]["processed_graph_data_path"] = "graphium/data/cache/profiling/predictor_data.cache" + cfg["datamodule"]["args"][ + "processed_graph_data_path" + ] = "graphium/data/cache/profiling/predictor_data.cache" # cfg["datamodule"]["args"]["df_path"] = DATA_PATH cfg["trainer"]["trainer"]["max_epochs"] = 5 cfg["trainer"]["trainer"]["min_epochs"] = 5 From 0497a53a217f88da5405c77d264b3f11c7ff93be Mon Sep 17 00:00:00 2001 From: WenkelF Date: Thu, 17 Aug 2023 10:53:27 -0400 Subject: [PATCH 05/14] Updating caching and dataloading from disk/ram --- graphium/data/datamodule.py | 94 +++++++++++++++++++------------------ graphium/data/dataset.py | 69 ++++++++++++++++++++++----- 2 files changed, 105 insertions(+), 58 deletions(-) diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index c898e09c8..1b68c50de 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1,6 +1,7 @@ import tempfile from contextlib import redirect_stderr, redirect_stdout from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable, Literal +from 
os import PathLike as Path from dataclasses import dataclass @@ -135,6 +136,7 @@ def __init__( self._predict_ds = None self._data_is_prepared = False + self._data_is_cached = False def prepare_data(self): raise NotImplementedError() @@ -932,6 +934,11 @@ def __init__( ) self.data_hash = self.get_data_hash() + if self.processed_graph_data_path is not None: + if self._ready_to_load_all_from_file(): + self._data_is_prepared = True + self._data_is_cached = True + def _parse_caching_args(self, processed_graph_data_path, dataloading_from): """ Parse the caching arguments, and raise errors if the arguments are invalid. @@ -994,15 +1001,10 @@ def has_atoms_after_h_removal(smiles): return has_atoms if self._data_is_prepared: - logger.info("Data is already prepared. Skipping the preparation") + logger.info("Data is already prepared.") + self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None) return - if self.dataloading_from == "disk": - if self._ready_to_load_all_from_file(): - self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None) - self._data_is_prepared = True - return - """Load all single-task dataframes.""" task_df = {} for task, args in self.task_dataset_processing_params.items(): @@ -1172,6 +1174,7 @@ def has_atoms_after_h_removal(smiles): if self.processed_graph_data_path is not None: self._save_data_to_files() + self._data_is_cached = True self._data_is_prepared = True @@ -1191,21 +1194,29 @@ def setup( labels_size = {} labels_dtype = {} if stage == "fit" or stage is None: - if self.dataloading_from == "disk": - processed_train_data_path = self._path_to_load_from_file("train") - assert self._data_ready_at_path( - processed_train_data_path - ), "Loading from file + setup() called but training data not ready" - processed_val_data_path = self._path_to_load_from_file("val") - assert self._data_ready_at_path( - processed_val_data_path - ), "Loading from file + setup() called but validation data not ready" - else: - processed_train_data_path = None - processed_val_data_path = None + # if self.dataloading_from == "disk": + # processed_train_data_path = self._path_to_load_from_file("train") + # assert self._data_ready_at_path( + # processed_train_data_path + # ), "Loading from file + setup() called but training data not ready" + # processed_val_data_path = self._path_to_load_from_file("val") + # assert self._data_ready_at_path( + # processed_val_data_path + # ), "Loading from file + setup() called but validation data not ready" + # else: + # processed_train_data_path = None + # processed_val_data_path = None + + # if not self._data_is_setup: + if self.train_ds is None: + self.train_ds = self._make_multitask_dataset( + self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids + ) + if self.val_ds is None: + self.val_ds = self._make_multitask_dataset( + self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids + ) - self.train_ds = self._make_multitask_dataset("train", save_smiles_and_ids=save_smiles_and_ids) - self.val_ds = self._make_multitask_dataset("val", save_smiles_and_ids=save_smiles_and_ids) logger.info(self.train_ds) logger.info(self.val_ds) labels_size.update( @@ -1216,14 +1227,11 @@ def setup( labels_dtype.update(self.val_ds.labels_dtype) if stage == "test" or stage is None: - if self.dataloading_from == "disk": - processed_test_data_path = self._path_to_load_from_file("test") - assert self._data_ready_at_path( - processed_test_data_path - ), "Loading from file + setup() called but test data not 
ready" - else: - processed_test_data_path = None - self.test_ds = self._make_multitask_dataset("test", save_smiles_and_ids=save_smiles_and_ids) + if self.test_ds is None: + self.test_ds = self._make_multitask_dataset( + self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids + ) + logger.info(self.test_ds) labels_size.update(self.test_ds.labels_size) @@ -1241,9 +1249,10 @@ def setup( def _make_multitask_dataset( self, + dataloading_from: Literal["disk", "ram"], stage: Literal["train", "val", "test"], save_smiles_and_ids: bool, - processed_graph_data_path: Optional[str] = None, + # processed_graph_data_path: Optional[str] = None, ) -> Datasets.MultitaskDataset: """ Create a MultitaskDataset for the given stage using single task datasets @@ -1270,18 +1279,7 @@ def _make_multitask_dataset( else: raise ValueError(f"Unknown stage {stage}") - if processed_graph_data_path is None: - processed_graph_data_path = self.processed_graph_data_path - - # assert singletask_datasets is not None, "Single task datasets must exist to make multitask dataset" - if singletask_datasets is None: - assert processed_graph_data_path is not None - assert self._data_ready_at_path( - self._path_to_load_from_file(stage) - ), "Trying to create multitask dataset without single-task datasets but data not ready" - files_ready = True - else: - files_ready = False + processed_graph_data_path = self.processed_graph_data_path multitask_dataset = Datasets.MultitaskDataset( singletask_datasets, @@ -1292,8 +1290,8 @@ def _make_multitask_dataset( about=about, save_smiles_and_ids=save_smiles_and_ids, data_path=self._path_to_load_from_file(stage) if processed_graph_data_path else None, - processed_graph_data_path=processed_graph_data_path, - files_ready=files_ready, + dataloading_from=dataloading_from, + data_is_cached=self._data_is_cached, ) # type: ignore # calculate statistics for the train split and used for all splits normalization @@ -1301,7 +1299,8 @@ def _make_multitask_dataset( self.get_label_statistics( self.processed_graph_data_path, self.data_hash, multitask_dataset, train=True ) - if self.dataloading_from == "ram": + # Normalization has already been applied in cached data + if not self._data_is_prepared: self.normalize_label(multitask_dataset, stage) return multitask_dataset @@ -1342,7 +1341,9 @@ def _save_data_to_files(self) -> None: # At the moment, we need to merge the `SingleTaskDataset`'s into `MultitaskDataset`s in order to save to file # This is because the combined labels need to be stored together. 
We can investigate not doing this if this is a problem temp_datasets = { - stage: self._make_multitask_dataset(stage, save_smiles_and_ids=False, load_from_file=False) + stage: self._make_multitask_dataset( + dataloading_from="ram", stage=stage, save_smiles_and_ids=False + ) for stage in stages } for stage in stages: @@ -1364,6 +1365,7 @@ def calculate_statistics(self, dataset: Datasets.MultitaskDataset, train: bool = train: whether the dataset is the training set """ + if self.task_norms and train: for task in dataset.labels_size.keys(): # if the label type is graph_*, we need to stack them as the tensor shape is (num_labels, ) diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py index c89758d29..7b01713d7 100644 --- a/graphium/data/dataset.py +++ b/graphium/data/dataset.py @@ -8,6 +8,8 @@ import os import numpy as np +from datamol import parallelized, parallelized_with_batches + import torch from torch.utils.data.dataloader import Dataset from torch_geometric.data import Data, Batch @@ -147,7 +149,7 @@ def __init__( about: str = "", data_path: Optional[Union[str, os.PathLike]] = None, dataloading_from: str = "ram", - files_ready: bool = False, + data_is_cached: bool = False, ): r""" This class holds the information for the multitask dataset. @@ -176,7 +178,6 @@ def __init__( files_ready: Whether the files to load from were prepared ahead of time """ super().__init__() - # self.datasets = datasets self.n_jobs = n_jobs self.backend = backend self.featurization_batch_size = featurization_batch_size @@ -185,14 +186,17 @@ def __init__( self.data_path = data_path self.dataloading_from = dataloading_from - if files_ready: - if dataloading_from != "disk": - raise ValueError( - "Files are ready to be loaded from disk, but `dataloading_from` is not set to `disk`" - ) + logger.info(f"Dataloading from {dataloading_from.upper()}") + + if data_is_cached: self._load_metadata() - self.features = None - self.labels = None + + if dataloading_from == "disk": + self.features = None + self.labels = None + elif dataloading_from == "ram": + logger.info("Transferring data from DISK to RAM...") + self.transfer_from_disk_to_ram() else: task = next(iter(datasets)) @@ -213,9 +217,50 @@ def __init__( if self.features is not None: self._num_nodes_list = get_num_nodes_per_graph(self.features) self._num_edges_list = get_num_edges_per_graph(self.features) - if self.dataloading_from == "disk": - self.features = None - self.labels = None + + def transfer_from_disk_to_ram(self, parallel_with_batches: bool = False): + """ + Function parallelizing transfer from DISK to RAM + """ + + def transfer_mol_from_disk_to_ram(idx): + """ + Function transferring single mol from DISK to RAM + """ + data_dict = self.load_graph_from_index(idx) + mol_in_ram = {} + mol_in_ram.update({"features": data_dict["graph_with_features"]}) + mol_in_ram.update({"labels": data_dict["labels"]}) + if self.smiles is not None: + mol_in_ram.update({"smiles": data_dict["smiles"]}) + + return mol_in_ram + + if parallel_with_batches and self.featurization_batch_size: + data_in_ram = parallelized_with_batches( + transfer_mol_from_disk_to_ram, + range(self.dataset_length), + batch_size=self.featurization_batch_size, + n_jobs=self.n_jobs, + backend=self.backend, + progress=self.progress, + tqdm_kwargs={"desc": "Transfer from DISK to RAM"}, + ) + else: + data_in_ram = parallelized( + transfer_mol_from_disk_to_ram, + range(self.dataset_length), + n_jobs=self.n_jobs, + backend=self.backend, + progress=self.progress, + tqdm_kwargs={"desc": "Transfer from 
DISK to RAM"}, + ) + + self.features = [sample["features"] for sample in data_in_ram] + self.labels = [sample["labels"] for sample in data_in_ram] + self.smiles = None + if "smiles" in self.load_graph_from_index(0): + self.smiles = [sample["smiles"] for sample in data_in_ram] def save_metadata(self, directory: str): """ From f307ed5fb0e86605a791bce82498f90feb87d460 Mon Sep 17 00:00:00 2001 From: WenkelF Date: Thu, 17 Aug 2023 10:54:59 -0400 Subject: [PATCH 06/14] Adding option to prepare data in advance to CLI --- graphium/cli/prepare_data.py | 40 ++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 41 insertions(+) create mode 100644 graphium/cli/prepare_data.py diff --git a/graphium/cli/prepare_data.py b/graphium/cli/prepare_data.py new file mode 100644 index 000000000..50947378f --- /dev/null +++ b/graphium/cli/prepare_data.py @@ -0,0 +1,40 @@ +import hydra +import timeit + +from omegaconf import DictConfig, OmegaConf +from loguru import logger + +from graphium.config._loader import load_datamodule, load_accelerator + + +@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main") +def cli(cfg: DictConfig) -> None: + """ + CLI endpoint for preparing the data in advance. + """ + run_prepare_data(cfg) + + +def run_prepare_data(cfg: DictConfig) -> None: + """ + The main (pre-)training and fine-tuning loop. + """ + + cfg = OmegaConf.to_container(cfg, resolve=True) + + st = timeit.default_timer() + + ## == Instantiate all required objects from their respective configs == + # Accelerator + cfg, accelerator_type = load_accelerator(cfg) + + ## Data-module + datamodule = load_datamodule(cfg, accelerator_type) + + datamodule.prepare_data() + + logger.info(f"Data preparation took {timeit.default_timer() - st:.2f} seconds.") + + +if __name__ == "__main__": + cli() diff --git a/pyproject.toml b/pyproject.toml index 9e55eb5f9..20cfa9792 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ dependencies = [ [project.scripts] graphium = "graphium.cli.main:main_cli" graphium-train = "graphium.cli.train_finetune:cli" + graphium-prepare-data = "graphium.cli.prepare_data:cli" [project.urls] Website = "https://graphium.datamol.io/" From d088ea28a862e9d2f2e7dae6493c00b73084cf05 Mon Sep 17 00:00:00 2001 From: WenkelF Date: Thu, 17 Aug 2023 10:55:22 -0400 Subject: [PATCH 07/14] Updating README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 018eee517..b80947ef9 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,8 @@ graphium-train --config-path [PATH] --config-name [CONFIG] ``` Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium. +## Preparing the data in advance +The data preparation including the featurization (e.g., of molecules from smiles to pyg-compatible format) is embedded in the pipeline and will be performed when executing `graphium-train [...]`. However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). The command `graphium-prepare-data datamodule.args.processed_graph_data_path=[path_to_cached_data]` will prepare the data and cache it in the indicated location `[path_to_cached_data]`. The prepared data can be used for training via `graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]`. 
Note that `datamodule.args.processed_graph_data_path` can also be specified at `expts/hydra-configs/`.
 
 ## First Time Running on IPUs
 For new IPU developers this section helps provide some more explanation on how to set up an environment to use Graphcore IPUs with Graphium.

From b4eaff7b64f9ed711a118793fcc8640d128b6e9a Mon Sep 17 00:00:00 2001
From: DomInvivo <47570400+DomInvivo@users.noreply.github.com>
Date: Thu, 17 Aug 2023 11:22:58 -0400
Subject: [PATCH 08/14] Update README.md

---
 README.md | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b80947ef9..1299e5f5b 100644
--- a/README.md
+++ b/README.md
@@ -112,7 +112,22 @@ graphium-train --config-path [PATH] --config-name [CONFIG]
 Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium.
 
 ## Preparing the data in advance
-The data preparation, including the featurization (e.g., of molecules from SMILES to a PyG-compatible format), is embedded in the pipeline and will be performed when executing `graphium-train [...]`. However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). The command `graphium-prepare-data datamodule.args.processed_graph_data_path=[path_to_cached_data]` will prepare the data and cache it in the indicated location `[path_to_cached_data]`. The prepared data can be used for training via `graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]`. Note that `datamodule.args.processed_graph_data_path` can also be specified at `expts/hydra-configs/`.
+The data preparation, including the featurization (e.g., of molecules from SMILES to a PyG-compatible format), is embedded in the pipeline and will be performed when executing `graphium-train [...]`.
+
+However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). Preparing data in advance is also beneficial when running many concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict when reading/writing in the same directory.
+
+The following command lines first prepare the data and cache it, then use it to train a model.
+```bash
+# First prepare the data and cache it in `path_to_cached_data`
+graphium-prepare-data datamodule.args.processed_graph_data_path=[path_to_cached_data]
+
+# Then train the model on the prepared data
+graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]
+```
+
+**Note** that `datamodule.args.processed_graph_data_path` can also be specified at `expts/hydra-configs/`.
+
+**Note** that, every time the configs of `datamodule.args.featurization` change, you will need to run the data preparation again; the new cache will automatically be saved in a separate directory that uses a hash unique to the configs.
 
 ## First Time Running on IPUs
 For new IPU developers this section helps provide some more explanation on how to set up an environment to use Graphcore IPUs with Graphium. 
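The README notes above rest on one invariant: the cache directory is a pure function of the featurization configs, so changing the configs switches to a fresh cache instead of silently reusing a stale one. Below is a minimal sketch of that config-keyed caching idea; `cache_dir_for` is a hypothetical helper written for illustration, not Graphium's actual hashing code.

```python
import hashlib
import json
import os


def cache_dir_for(featurization: dict, root: str = "../datacache/neurips2023-small/") -> str:
    # Serialize the featurization config deterministically, then hash it,
    # so that any change to the config maps to a different sub-directory.
    config_blob = json.dumps(featurization, sort_keys=True).encode("utf-8")
    return os.path.join(root, hashlib.md5(config_blob).hexdigest())


# Two different configs land in two different cache directories.
print(cache_dir_for({"atom_property_list_onehot": ["atomic-number", "group"]}))
print(cache_dir_for({"atom_property_list_onehot": ["atomic-number"]}))
```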
From 5a23093bad3023189a153ee7d431ac040bfd814c Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 17 Aug 2023 19:39:31 -0400 Subject: [PATCH 09/14] Fixed the PR comments --- graphium/cli/prepare_data.py | 16 +++++++++------- graphium/data/datamodule.py | 14 -------------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/graphium/cli/prepare_data.py b/graphium/cli/prepare_data.py index 50947378f..7a8c6eceb 100644 --- a/graphium/cli/prepare_data.py +++ b/graphium/cli/prepare_data.py @@ -21,16 +21,18 @@ def run_prepare_data(cfg: DictConfig) -> None: """ cfg = OmegaConf.to_container(cfg, resolve=True) - st = timeit.default_timer() - ## == Instantiate all required objects from their respective configs == - # Accelerator - cfg, accelerator_type = load_accelerator(cfg) - - ## Data-module - datamodule = load_datamodule(cfg, accelerator_type) + # Checking that `processed_graph_data_path` is provided + path = cfg["datamodule"]["args"].get("processed_graph_data_path", None) + if path is None: + raise ValueError( + "Please provide `datamodule.args.processed_graph_data_path` to specify the caching dir." + ) + logger.info(f"The caching dir is set to '{path}'") + # Data-module + datamodule = load_datamodule(cfg, "cpu") datamodule.prepare_data() logger.info(f"Data preparation took {timeit.default_timer() - st:.2f} seconds.") diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 1b68c50de..03a8425ce 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1194,20 +1194,6 @@ def setup( labels_size = {} labels_dtype = {} if stage == "fit" or stage is None: - # if self.dataloading_from == "disk": - # processed_train_data_path = self._path_to_load_from_file("train") - # assert self._data_ready_at_path( - # processed_train_data_path - # ), "Loading from file + setup() called but training data not ready" - # processed_val_data_path = self._path_to_load_from_file("val") - # assert self._data_ready_at_path( - # processed_val_data_path - # ), "Loading from file + setup() called but validation data not ready" - # else: - # processed_train_data_path = None - # processed_val_data_path = None - - # if not self._data_is_setup: if self.train_ds is None: self.train_ds = self._make_multitask_dataset( self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids From ef1e30f8640ffed77bc662aacbdd689381a727d2 Mon Sep 17 00:00:00 2001 From: WenkelF Date: Fri, 18 Aug 2023 12:07:00 -0400 Subject: [PATCH 10/14] Fixing unit tests --- .../choosing_parallelization.ipynb | 22 ++-- expts/hydra-configs/architecture/toymix.yaml | 7 +- .../tasks/loss_metrics_datamodule/toymix.yaml | 5 +- .../training/accelerator/toymix_cpu.yaml | 2 +- graphium/data/datamodule.py | 34 +++--- graphium/data/dataset.py | 29 +++-- tests/test_datamodule.py | 115 +++++++++++++++--- 7 files changed, 154 insertions(+), 60 deletions(-) diff --git a/docs/tutorials/feature_processing/choosing_parallelization.ipynb b/docs/tutorials/feature_processing/choosing_parallelization.ipynb index 1ebb54451..0ab569d57 100644 --- a/docs/tutorials/feature_processing/choosing_parallelization.ipynb +++ b/docs/tutorials/feature_processing/choosing_parallelization.ipynb @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "b5df2ac6-2ded-4597-a445-f2b5fb106330", "metadata": { "tags": [] @@ -24,8 +24,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: Pandarallel will run on 240 workers.\n", - "INFO: Pandarallel will use Memory file system to transfer data 
between the main process and workers.\n" + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" ] } ], @@ -39,9 +39,9 @@ "import datamol as dm\n", "import pandas as pd\n", "\n", - "from pandarallel import pandarallel\n", + "# from pandarallel import pandarallel\n", "\n", - "pandarallel.initialize(progress_bar=True, nb_workers=joblib.cpu_count())" + "# pandarallel.initialize(progress_bar=True, nb_workers=joblib.cpu_count())" ] }, { @@ -54,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "0f31e18d-bdd9-4d9b-8ba5-81e5887b857e", "metadata": { "tags": [] @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "id": "a1197c31-7dbc-4fd7-a69a-5215e1a96b8e", "metadata": { "tags": [] @@ -109,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "id": "2f8ce5c3-4232-4279-8ea3-7a74832303be", "metadata": { "tags": [] @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "id": "a246cdcf-b5ea-4c9e-9ccc-dd3c544587bb", "metadata": { "tags": [] @@ -138,7 +138,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3e939cd3a24742038b804bbfd961377d", + "model_id": "cc396220c7144c8d8b195fb87694bbfe", "version_major": 2, "version_minor": 0 }, @@ -489,7 +489,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/expts/hydra-configs/architecture/toymix.yaml b/expts/hydra-configs/architecture/toymix.yaml index 6927f4e66..1523a93be 100644 --- a/expts/hydra-configs/architecture/toymix.yaml +++ b/expts/hydra-configs/architecture/toymix.yaml @@ -75,11 +75,12 @@ datamodule: module_type: "MultitaskFromSmilesDataModule" args: prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 + featurization_n_jobs: 0 featurization_progress: True featurization_backend: "loky" - processed_graph_data_path: "../datacache/neurips2023-small/" - num_workers: 30 # -1 to use all + processed_graph_data_path: "../datacache/dummy-toymix/" + dataloading_from: ram + num_workers: 0 # -1 to use all persistent_workers: False featurization: atom_property_list_onehot: [atomic-number, group, period, total-valence] diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml index 9ac744a52..575c95561 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml @@ -73,6 +73,7 @@ datamodule: label_normalization: normalize_val_test: True method: "normal" + sample_size: 200 tox21: df: null @@ -85,6 +86,7 @@ datamodule: splits_path: ${constants.data_dir}/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt` seed: ${constants.seed} task_level: graph + sample_size: 200 zinc: df: null @@ -99,4 +101,5 @@ datamodule: task_level: graph label_normalization: normalize_val_test: True - method: "normal" \ No newline at end of file + method: "normal" + sample_size: 200 \ No newline at end of file diff --git a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml index 9022eeb84..7f11d3831 100644 --- a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml +++ 
b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml @@ -5,7 +5,7 @@ datamodule: batch_size_training: 200 batch_size_inference: 200 featurization_n_jobs: 4 - num_workers: 4 + num_workers: 0 predictor: optim_kwargs: {} diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 03a8425ce..e8cab271d 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -976,7 +976,7 @@ def get_task_levels(self): return task_level_map - def prepare_data(self): + def prepare_data(self, save_smiles_and_ids: bool = False): """Called only from a single process in distributed settings. Steps: - If each cache is set and exists, reload from cache and return. Otherwise, @@ -1173,7 +1173,7 @@ def has_atoms_after_h_removal(smiles): ) if self.processed_graph_data_path is not None: - self._save_data_to_files() + self._save_data_to_files(save_smiles_and_ids) self._data_is_cached = True self._data_is_prepared = True @@ -1194,14 +1194,14 @@ def setup( labels_size = {} labels_dtype = {} if stage == "fit" or stage is None: - if self.train_ds is None: - self.train_ds = self._make_multitask_dataset( - self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids - ) - if self.val_ds is None: - self.val_ds = self._make_multitask_dataset( - self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids - ) + # if self.train_ds is None: + self.train_ds = self._make_multitask_dataset( + self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids + ) + # if self.val_ds is None: + self.val_ds = self._make_multitask_dataset( + self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids + ) logger.info(self.train_ds) logger.info(self.val_ds) @@ -1213,10 +1213,10 @@ def setup( labels_dtype.update(self.val_ds.labels_dtype) if stage == "test" or stage is None: - if self.test_ds is None: - self.test_ds = self._make_multitask_dataset( - self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids - ) + # if self.test_ds is None: + self.test_ds = self._make_multitask_dataset( + self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids + ) logger.info(self.test_ds) @@ -1238,7 +1238,6 @@ def _make_multitask_dataset( dataloading_from: Literal["disk", "ram"], stage: Literal["train", "val", "test"], save_smiles_and_ids: bool, - # processed_graph_data_path: Optional[str] = None, ) -> Datasets.MultitaskDataset: """ Create a MultitaskDataset for the given stage using single task datasets @@ -1317,7 +1316,7 @@ def _data_ready_at_path(self, path: str) -> bool: return can_load_from_file - def _save_data_to_files(self) -> None: + def _save_data_to_files(self, save_smiles_and_ids: bool = False) -> None: """ Save data to files so that they can be loaded from file during training/validation/test """ @@ -1328,7 +1327,7 @@ def _save_data_to_files(self) -> None: # This is because the combined labels need to be stored together. 
We can investigate not doing this if this is a problem
         temp_datasets = {
             stage: self._make_multitask_dataset(
-                dataloading_from="ram", stage=stage, save_smiles_and_ids=False
+                dataloading_from="ram", stage=stage, save_smiles_and_ids=save_smiles_and_ids
             )
             for stage in stages
         }
@@ -2220,6 +2219,7 @@ def __init__(
         dm_args = {}
         dm_args["task_specific_args"] = new_task_specific_args
         dm_args["processed_graph_data_path"] = processed_graph_data_path
+        dm_args["dataloading_from"] = dataloading_from
         dm_args["dataloader_from"] = dataloading_from
         dm_args["featurization"] = featurization
         dm_args["batch_size_training"] = batch_size_training
diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py
index 7b01713d7..039d1b35a 100644
--- a/graphium/data/dataset.py
+++ b/graphium/data/dataset.py
@@ -171,11 +171,9 @@ def __init__(
             progress: Whether to display the progress bar
             save_smiles_and_ids: Whether to save the smiles and ids for the dataset. If `False`, `mol_ids` and `smiles` are set to `None`
             about: A description of the dataset
-            progress: Whether to display the progress bar
-            about: A description of the dataset
             data_path: The location of the data if saved on disk
             dataloading_from: Whether to load the data from `"disk"` or `"ram"`
-            files_ready: Whether the files to load from were prepared ahead of time
+            data_is_cached: Whether the data is already cached on `"disk"`
         """
         super().__init__()
         self.n_jobs = n_jobs
@@ -183,6 +181,7 @@ def __init__(
         self.featurization_batch_size = featurization_batch_size
         self.progress = progress
         self.about = about
+        self.save_smiles_and_ids = save_smiles_and_ids
         self.data_path = data_path
         self.dataloading_from = dataloading_from
 
@@ -228,11 +227,10 @@ def transfer_mol_from_disk_to_ram(idx):
             Function transferring single mol from DISK to RAM
             """
             data_dict = self.load_graph_from_index(idx)
-            mol_in_ram = {}
-            mol_in_ram.update({"features": data_dict["graph_with_features"]})
-            mol_in_ram.update({"labels": data_dict["labels"]})
-            if self.smiles is not None:
-                mol_in_ram.update({"smiles": data_dict["smiles"]})
+            mol_in_ram = {
+                "features": data_dict["graph_with_features"],
+                "labels": data_dict["labels"],
+            }
 
             return mol_in_ram
 
@@ -241,7 +239,7 @@ def transfer_mol_from_disk_to_ram(idx):
                 transfer_mol_from_disk_to_ram,
                 range(self.dataset_length),
                 batch_size=self.featurization_batch_size,
-                n_jobs=self.n_jobs,
+                n_jobs=0,
                 backend=self.backend,
                 progress=self.progress,
                 tqdm_kwargs={"desc": "Transfer from DISK to RAM"},
             )
@@ -250,7 +248,7 @@ def transfer_mol_from_disk_to_ram(idx):
             data_in_ram = parallelized(
                 transfer_mol_from_disk_to_ram,
                 range(self.dataset_length),
-                n_jobs=self.n_jobs,
+                n_jobs=0,
                 backend=self.backend,
                 progress=self.progress,
                 tqdm_kwargs={"desc": "Transfer from DISK to RAM"},
             )
@@ -258,9 +256,6 @@ def transfer_mol_from_disk_to_ram(idx):
 
         self.features = [sample["features"] for sample in data_in_ram]
         self.labels = [sample["labels"] for sample in data_in_ram]
-        self.smiles = None
-        if "smiles" in self.load_graph_from_index(0):
-            self.smiles = [sample["smiles"] for sample in data_in_ram]
 
     def save_metadata(self, directory: str):
         """
@@ -309,6 +304,14 @@ def _load_metadata(self):
         for attr, value in attrs.items():
             setattr(self, attr, value)
 
+        if self.save_smiles_and_ids:
+            if self.smiles is None or self.mol_ids is None:
+                logger.warning(
+                    f"Argument `save_smiles_and_ids` is set to {self.save_smiles_and_ids} but metadata in the cache at {self.data_path} does not contain smiles and mol_ids. "
+                    f"This may be because `Datamodule.prepare_data(save_smiles_and_ids=False)` was run followed by `Datamodule.setup(save_smiles_and_ids=True)`. "
+                    f"When loading from cached files, the `save_smiles_and_ids` argument of `Datamodule.setup()` is superseded by the value passed to `Datamodule.prepare_data()`. "
+                )
+
     def __len__(self):
         r"""
         Returns the number of molecules
diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py
index 45cd30106..bbf531527 100644
--- a/tests/test_datamodule.py
+++ b/tests/test_datamodule.py
@@ -37,19 +37,30 @@ def test_ogb_datamodule(self):
         dm_args["batch_size_inference"] = 16
         dm_args["num_workers"] = 0
         dm_args["pin_memory"] = True
-        dm_args["featurization_n_jobs"] = 2
+        dm_args["featurization_n_jobs"] = 0
         dm_args["featurization_progress"] = True
         dm_args["featurization_backend"] = "loky"
         dm_args["featurization_batch_size"] = 50
 
         ds = GraphOGBDataModule(task_specific_args, **dm_args)
 
-        ds.prepare_data()
+        ds.prepare_data(save_smiles_and_ids=False)
 
         # Check the keys in the dataset
         ds.setup(save_smiles_and_ids=False)
         assert set(ds.train_ds[0].keys()) == {"features", "labels"}
 
+        # Delete the cache if already exist
+        if exists(TEMP_CACHE_DATA_PATH):
+            rm(TEMP_CACHE_DATA_PATH, recursive=True)
+
+        # Reset the datamodule
+        ds._data_is_prepared = False
+        ds._data_is_cached = False
+
+        ds.prepare_data(save_smiles_and_ids=True)
+
+        # Check the keys in the dataset
         ds.setup(save_smiles_and_ids=True)
         assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"}
 
@@ -163,7 +174,6 @@ def test_caching(self):
         featurization_args = {}
         featurization_args["atom_property_list_float"] = []  # ["weight", "valence"]
         featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"]
-        # featurization_args["conformer_property_list"] = ["positions_3d"]
         featurization_args["edge_property_list"] = ["bond-type-onehot"]
         featurization_args["add_self_loop"] = False
         featurization_args["use_bonds_weights"] = False
@@ -178,7 +188,7 @@ def test_caching(self):
         dm_args["batch_size_inference"] = 16
         dm_args["num_workers"] = 0
         dm_args["pin_memory"] = True
-        dm_args["featurization_n_jobs"] = 2
+        dm_args["featurization_n_jobs"] = 0
         dm_args["featurization_progress"] = True
         dm_args["featurization_backend"] = "loky"
         dm_args["featurization_batch_size"] = 50
@@ -190,23 +200,96 @@ def test_caching(self):
         # Prepare the data. 
It should create the cache there assert not exists(TEMP_CACHE_DATA_PATH) ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) - assert not ds.load_data_from_cache(verbose=False) - ds.prepare_data() + # assert not ds.load_data_from_cache(verbose=False) + ds.prepare_data(save_smiles_and_ids=False) # Check the keys in the dataset ds.setup(save_smiles_and_ids=False) assert set(ds.train_ds[0].keys()) == {"features", "labels"} - ds.setup(save_smiles_and_ids=True) - assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} + # ds_batch = next(iter(ds.train_dataloader())) + train_loader = ds.get_dataloader(ds.train_ds, shuffle=False, stage="train") + batch = next(iter(train_loader)) + + # Test loading cached data + assert exists(TEMP_CACHE_DATA_PATH) + + cached_ds_from_ram = GraphOGBDataModule( + task_specific_args, + processed_graph_data_path=TEMP_CACHE_DATA_PATH, + dataloading_from="ram", + **dm_args, + ) + cached_ds_from_ram.prepare_data() + cached_ds_from_ram.setup() + cached_train_loader_from_ram = cached_ds_from_ram.get_dataloader( + cached_ds_from_ram.train_ds, shuffle=False, stage="train" + ) + batch_from_ram = next(iter(cached_train_loader_from_ram)) + + cached_ds_from_disk = GraphOGBDataModule( + task_specific_args, + processed_graph_data_path=TEMP_CACHE_DATA_PATH, + dataloading_from="disk", + **dm_args, + ) + cached_ds_from_disk.prepare_data() + cached_ds_from_disk.setup() + cached_train_loader_from_disk = cached_ds_from_disk.get_dataloader( + cached_ds_from_disk.train_ds, shuffle=False, stage="train" + ) + batch_from_disk = next(iter(cached_train_loader_from_disk)) + + # Features are the same + assert torch.equal(batch["features"].edge_index, batch_from_ram["features"].edge_index) + assert torch.equal(batch["features"].edge_index, batch_from_disk["features"].edge_index) + + assert batch["features"].num_nodes == batch_from_ram["features"].num_nodes + assert batch["features"].num_nodes == batch_from_disk["features"].num_nodes - # Make sure that the cache is created - full_cache_path = ds.get_data_cache_fullname(compress=False) - assert exists(full_cache_path) - assert get_size(full_cache_path) > 10000 + assert torch.equal(batch["features"].edge_weight, batch_from_ram["features"].edge_weight) + assert torch.equal(batch["features"].edge_weight, batch_from_disk["features"].edge_weight) - # Check that the data is loaded correctly from cache - assert ds.load_data_from_cache(verbose=False) + assert torch.equal(batch["features"].feat, batch_from_ram["features"].feat) + assert torch.equal(batch["features"].feat, batch_from_disk["features"].feat) + + assert torch.equal(batch["features"].edge_feat, batch_from_ram["features"].edge_feat) + assert torch.equal(batch["features"].edge_feat, batch_from_disk["features"].edge_feat) + + assert torch.equal(batch["features"].batch, batch_from_ram["features"].batch) + assert torch.equal(batch["features"].batch, batch_from_disk["features"].batch) + + assert torch.equal(batch["features"].ptr, batch_from_ram["features"].ptr) + assert torch.equal(batch["features"].ptr, batch_from_disk["features"].ptr) + + # Labels are the same + assert torch.equal(batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1) + assert torch.equal(batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1) + + assert torch.equal(batch["labels"].x, batch_from_ram["labels"].x) + assert torch.equal(batch["labels"].x, batch_from_disk["labels"].x) + + assert torch.equal(batch["labels"].edge_index, 
batch_from_ram["labels"].edge_index) + assert torch.equal(batch["labels"].edge_index, batch_from_disk["labels"].edge_index) + + assert torch.equal(batch["labels"].batch, batch_from_ram["labels"].batch) + assert torch.equal(batch["labels"].batch, batch_from_disk["labels"].batch) + + assert torch.equal(batch["labels"].ptr, batch_from_ram["labels"].ptr) + assert torch.equal(batch["labels"].ptr, batch_from_disk["labels"].ptr) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + # Reset the datamodule + ds._data_is_prepared = False + ds._data_is_cached = False + + ds.prepare_data(save_smiles_and_ids=True) + + ds.setup(save_smiles_and_ids=True) + assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} # test module assert ds.num_edge_feats == 5 @@ -219,6 +302,10 @@ def test_caching(self): assert len(batch["labels"]["graph_task_1"]) == 16 assert len(batch["mol_ids"]) == 16 + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + def test_datamodule_with_none_molecules(self): # Setup the featurization featurization_args = {} From d49241778626e29e254ef7987d6012737c2a8ca2 Mon Sep 17 00:00:00 2001 From: WenkelF Date: Fri, 18 Aug 2023 12:22:50 -0400 Subject: [PATCH 11/14] Undoing some unintentional changes --- expts/hydra-configs/architecture/toymix.yaml | 6 +++--- .../hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml | 5 +---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/expts/hydra-configs/architecture/toymix.yaml b/expts/hydra-configs/architecture/toymix.yaml index 1523a93be..c79325919 100644 --- a/expts/hydra-configs/architecture/toymix.yaml +++ b/expts/hydra-configs/architecture/toymix.yaml @@ -75,12 +75,12 @@ datamodule: module_type: "MultitaskFromSmilesDataModule" args: prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 0 + featurization_n_jobs: 30 featurization_progress: True featurization_backend: "loky" - processed_graph_data_path: "../datacache/dummy-toymix/" + processed_graph_data_path: "../datacache/neurips2023-small/" dataloading_from: ram - num_workers: 0 # -1 to use all + num_workers: 30 # -1 to use all persistent_workers: False featurization: atom_property_list_onehot: [atomic-number, group, period, total-valence] diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml index 575c95561..9ac744a52 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml @@ -73,7 +73,6 @@ datamodule: label_normalization: normalize_val_test: True method: "normal" - sample_size: 200 tox21: df: null @@ -86,7 +85,6 @@ datamodule: splits_path: ${constants.data_dir}/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt` seed: ${constants.seed} task_level: graph - sample_size: 200 zinc: df: null @@ -101,5 +99,4 @@ datamodule: task_level: graph label_normalization: normalize_val_test: True - method: "normal" - sample_size: 200 \ No newline at end of file + method: "normal" \ No newline at end of file From 7665b67a5fc4bd0d28856d79a487aec7dc59c24d Mon Sep 17 00:00:00 2001 From: WenkelF Date: Fri, 18 Aug 2023 12:25:43 -0400 Subject: [PATCH 12/14] Undoing some unintentional changes --- expts/hydra-configs/training/accelerator/toymix_cpu.yaml | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml index 7f11d3831..9022eeb84 100644 --- a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml +++ b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml @@ -5,7 +5,7 @@ datamodule: batch_size_training: 200 batch_size_inference: 200 featurization_n_jobs: 4 - num_workers: 0 + num_workers: 4 predictor: optim_kwargs: {} From 4240bbbda4fdb562173b203fd78554c8aa3117c1 Mon Sep 17 00:00:00 2001 From: WenkelF Date: Fri, 18 Aug 2023 13:40:04 -0400 Subject: [PATCH 13/14] Fixing datamodule unit test and unit test speedup --- tests/test_datamodule.py | 52 ++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index bbf531527..510658f50 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -241,42 +241,42 @@ def test_caching(self): batch_from_disk = next(iter(cached_train_loader_from_disk)) # Features are the same - assert torch.equal(batch["features"].edge_index, batch_from_ram["features"].edge_index) - assert torch.equal(batch["features"].edge_index, batch_from_disk["features"].edge_index) + np.testing.assert_array_almost_equal(batch["features"].edge_index, batch_from_ram["features"].edge_index) + np.testing.assert_array_almost_equal(batch["features"].edge_index, batch_from_disk["features"].edge_index) assert batch["features"].num_nodes == batch_from_ram["features"].num_nodes assert batch["features"].num_nodes == batch_from_disk["features"].num_nodes - assert torch.equal(batch["features"].edge_weight, batch_from_ram["features"].edge_weight) - assert torch.equal(batch["features"].edge_weight, batch_from_disk["features"].edge_weight) + np.testing.assert_array_almost_equal(batch["features"].edge_weight, batch_from_ram["features"].edge_weight) + np.testing.assert_array_almost_equal(batch["features"].edge_weight, batch_from_disk["features"].edge_weight) - assert torch.equal(batch["features"].feat, batch_from_ram["features"].feat) - assert torch.equal(batch["features"].feat, batch_from_disk["features"].feat) + np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_ram["features"].feat) + np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_disk["features"].feat) - assert torch.equal(batch["features"].edge_feat, batch_from_ram["features"].edge_feat) - assert torch.equal(batch["features"].edge_feat, batch_from_disk["features"].edge_feat) + np.testing.assert_array_almost_equal(batch["features"].edge_feat, batch_from_ram["features"].edge_feat) + np.testing.assert_array_almost_equal(batch["features"].edge_feat, batch_from_disk["features"].edge_feat) - assert torch.equal(batch["features"].batch, batch_from_ram["features"].batch) - assert torch.equal(batch["features"].batch, batch_from_disk["features"].batch) + np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_ram["features"].batch) + np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_disk["features"].batch) - assert torch.equal(batch["features"].ptr, batch_from_ram["features"].ptr) - assert torch.equal(batch["features"].ptr, batch_from_disk["features"].ptr) + np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_ram["features"].ptr) + np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_disk["features"].ptr) # Labels are the same - assert 
torch.equal(batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1) - assert torch.equal(batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1) + np.testing.assert_array_almost_equal(batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1) + np.testing.assert_array_almost_equal(batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1) - assert torch.equal(batch["labels"].x, batch_from_ram["labels"].x) - assert torch.equal(batch["labels"].x, batch_from_disk["labels"].x) + np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_ram["labels"].x) + np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_disk["labels"].x) - assert torch.equal(batch["labels"].edge_index, batch_from_ram["labels"].edge_index) - assert torch.equal(batch["labels"].edge_index, batch_from_disk["labels"].edge_index) + np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_ram["labels"].edge_index) + np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_disk["labels"].edge_index) - assert torch.equal(batch["labels"].batch, batch_from_ram["labels"].batch) - assert torch.equal(batch["labels"].batch, batch_from_disk["labels"].batch) + np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_ram["labels"].batch) + np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_disk["labels"].batch) - assert torch.equal(batch["labels"].ptr, batch_from_ram["labels"].ptr) - assert torch.equal(batch["labels"].ptr, batch_from_disk["labels"].ptr) + np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_ram["labels"].ptr) + np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_disk["labels"].ptr) # Delete the cache if already exist if exists(TEMP_CACHE_DATA_PATH): @@ -422,7 +422,7 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args) + ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) ds.prepare_data() ds.setup() @@ -435,7 +435,7 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args) + ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) ds.prepare_data() ds.setup() @@ -448,7 +448,7 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args) + ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) ds.prepare_data() ds.setup() @@ -461,7 +461,7 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args) + ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) ds.prepare_data() ds.setup() From cc91bfa72a0f7352f0491a81f8595cd1b71f2a57 Mon Sep 17 00:00:00 2001 From: WenkelF Date: Fri, 18 Aug 2023 13:40:38 -0400 Subject: [PATCH 14/14] Reformatting with black --- tests/test_datamodule.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 510658f50..2bc89200c 100644 --- 
a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -241,20 +241,32 @@ def test_caching(self): batch_from_disk = next(iter(cached_train_loader_from_disk)) # Features are the same - np.testing.assert_array_almost_equal(batch["features"].edge_index, batch_from_ram["features"].edge_index) - np.testing.assert_array_almost_equal(batch["features"].edge_index, batch_from_disk["features"].edge_index) + np.testing.assert_array_almost_equal( + batch["features"].edge_index, batch_from_ram["features"].edge_index + ) + np.testing.assert_array_almost_equal( + batch["features"].edge_index, batch_from_disk["features"].edge_index + ) assert batch["features"].num_nodes == batch_from_ram["features"].num_nodes assert batch["features"].num_nodes == batch_from_disk["features"].num_nodes - np.testing.assert_array_almost_equal(batch["features"].edge_weight, batch_from_ram["features"].edge_weight) - np.testing.assert_array_almost_equal(batch["features"].edge_weight, batch_from_disk["features"].edge_weight) + np.testing.assert_array_almost_equal( + batch["features"].edge_weight, batch_from_ram["features"].edge_weight + ) + np.testing.assert_array_almost_equal( + batch["features"].edge_weight, batch_from_disk["features"].edge_weight + ) np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_ram["features"].feat) np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_disk["features"].feat) - np.testing.assert_array_almost_equal(batch["features"].edge_feat, batch_from_ram["features"].edge_feat) - np.testing.assert_array_almost_equal(batch["features"].edge_feat, batch_from_disk["features"].edge_feat) + np.testing.assert_array_almost_equal( + batch["features"].edge_feat, batch_from_ram["features"].edge_feat + ) + np.testing.assert_array_almost_equal( + batch["features"].edge_feat, batch_from_disk["features"].edge_feat + ) np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_ram["features"].batch) np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_disk["features"].batch) @@ -263,8 +275,12 @@ def test_caching(self): np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_disk["features"].ptr) # Labels are the same - np.testing.assert_array_almost_equal(batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1) - np.testing.assert_array_almost_equal(batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1) + np.testing.assert_array_almost_equal( + batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1 + ) + np.testing.assert_array_almost_equal( + batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1 + ) np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_ram["labels"].x) np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_disk["labels"].x)
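Taken together, the patches above converge on a single caching workflow: featurize the molecules once into `processed_graph_data_path`, then dataload either from RAM or from disk against that cache. The end-to-end sketch below illustrates it; the import location, dataset file, and column names are assumptions for illustration rather than values taken from the patches.

```python
from graphium.data import MultitaskFromSmilesDataModule  # assumed import location

# Hypothetical single-task setup, mirroring the shape used in the tests above.
task_specific_args = {
    "task": {
        "task_level": "graph",
        "df_path": "data/my_dataset.csv",  # hypothetical CSV with SMILES and labels
        "smiles_col": "SMILES",
        "label_cols": ["score"],
    }
}

datamodule = MultitaskFromSmilesDataModule(
    task_specific_args,
    processed_graph_data_path="../datacache/demo/",  # where featurized graphs get cached
    dataloading_from="ram",  # "disk" is also valid here, since a cache path is provided
    featurization_n_jobs=0,
)

# Featurize the molecules and write the cache (skipped if already prepared).
datamodule.prepare_data(save_smiles_and_ids=False)

# Build the train/val/test MultitaskDataset objects from the cache.
datamodule.setup()
```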