Caching logic improvement #432

Merged: 18 commits, Aug 18, 2023
Changes from 3 commits
README.md (2 additions, 0 deletions)
@@ -111,6 +111,8 @@ graphium-train --config-path [PATH] --config-name [CONFIG]
```
Thanks to the modular nature of `hydra`, you can reuse many of our config settings for your own experiments with Graphium.

## Preparing the data in advance
Data preparation, including featurization (e.g., converting molecules from SMILES to a PyG-compatible format), is embedded in the pipeline and is performed automatically when executing `graphium-train [...]`. However, when working with larger datasets, it is recommended to prepare the data in advance on a machine with sufficient memory (e.g., ~400GB in the case of `LargeMix`). The command `graphium-prepare-data datamodule.args.processed_graph_data_path=[path_to_cached_data]` prepares the data and caches it at `[path_to_cached_data]`. The cached data can then be used for training via `graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]`. Note that `datamodule.args.processed_graph_data_path` can also be specified in the configs under `expts/hydra_configs/`.
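For example, a typical two-step workflow might look as follows (a minimal sketch; `/scratch/graphium_cache` is a placeholder path, not a default):

```bash
# Step 1: featurize and cache the data once, on a machine with enough memory
graphium-prepare-data datamodule.args.processed_graph_data_path=/scratch/graphium_cache

# Step 2: train against the cached data
graphium-train datamodule.args.processed_graph_data_path=/scratch/graphium_cache
```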

## First Time Running on IPUs
For new IPU developers, this section provides more detail on how to set up an environment for using Graphcore IPUs with Graphium.
graphium/cli/prepare_data.py (40 additions, 0 deletions)
@@ -0,0 +1,40 @@
import hydra
import timeit

from omegaconf import DictConfig, OmegaConf
from loguru import logger

from graphium.config._loader import load_datamodule, load_accelerator


@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main")
def cli(cfg: DictConfig) -> None:
    """
    CLI endpoint for preparing the data in advance.
    """
    run_prepare_data(cfg)


def run_prepare_data(cfg: DictConfig) -> None:
    """
    Runs the data preparation (featurization and caching) for the configured datamodule.
    """

    cfg = OmegaConf.to_container(cfg, resolve=True)

    st = timeit.default_timer()

    ## == Instantiate all required objects from their respective configs ==
    # Accelerator
    cfg, accelerator_type = load_accelerator(cfg)

    ## Data-module
    datamodule = load_datamodule(cfg, accelerator_type)

    datamodule.prepare_data()

    logger.info(f"Data preparation took {timeit.default_timer() - st:.2f} seconds.")


if __name__ == "__main__":
    cli()
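Because `cli` is a standard Hydra entry point, the new command accepts the generic Hydra flags in addition to config overrides. A quick way to inspect what will run, using only stock Hydra options (no Graphium-specific flags assumed):

```bash
# Show the app help, including the available config groups
graphium-prepare-data --help

# Print the composed job config and exit, without preparing any data
graphium-prepare-data --cfg job
```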
graphium/data/datamodule.py (48 additions, 46 deletions)
@@ -1,6 +1,7 @@
import tempfile
from contextlib import redirect_stderr, redirect_stdout
from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable, Literal
from os import PathLike as Path

from dataclasses import dataclass

@@ -135,6 +136,7 @@ def __init__(
self._predict_ds = None

self._data_is_prepared = False
self._data_is_cached = False

def prepare_data(self):
raise NotImplementedError()
@@ -932,6 +934,11 @@ def __init__(
)
self.data_hash = self.get_data_hash()

if self.processed_graph_data_path is not None:
if self._ready_to_load_all_from_file():
self._data_is_prepared = True
self._data_is_cached = True

def _parse_caching_args(self, processed_graph_data_path, dataloading_from):
"""
Parse the caching arguments, and raise errors if the arguments are invalid.
@@ -994,15 +1001,10 @@ def has_atoms_after_h_removal(smiles):
return has_atoms

if self._data_is_prepared:
logger.info("Data is already prepared. Skipping the preparation")
logger.info("Data is already prepared.")
self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None)
return

if self.dataloading_from == "disk":
if self._ready_to_load_all_from_file():
self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None)
self._data_is_prepared = True
return

"""Load all single-task dataframes."""
task_df = {}
for task, args in self.task_dataset_processing_params.items():
@@ -1172,6 +1174,7 @@ def has_atoms_after_h_removal(smiles):

if self.processed_graph_data_path is not None:
self._save_data_to_files()
self._data_is_cached = True

self._data_is_prepared = True

@@ -1191,21 +1194,29 @@ def setup(
labels_size = {}
labels_dtype = {}
if stage == "fit" or stage is None:
if self.dataloading_from == "disk":
processed_train_data_path = self._path_to_load_from_file("train")
assert self._data_ready_at_path(
processed_train_data_path
), "Loading from file + setup() called but training data not ready"
processed_val_data_path = self._path_to_load_from_file("val")
assert self._data_ready_at_path(
processed_val_data_path
), "Loading from file + setup() called but validation data not ready"
else:
processed_train_data_path = None
processed_val_data_path = None
# if self.dataloading_from == "disk":
# processed_train_data_path = self._path_to_load_from_file("train")
# assert self._data_ready_at_path(
# processed_train_data_path
# ), "Loading from file + setup() called but training data not ready"
# processed_val_data_path = self._path_to_load_from_file("val")
# assert self._data_ready_at_path(
# processed_val_data_path
# ), "Loading from file + setup() called but validation data not ready"
# else:
# processed_train_data_path = None
# processed_val_data_path = None

Collaborator (review comment): Forgot to remove commented lines. Will do so shortly.

# if not self._data_is_setup:
if self.train_ds is None:
self.train_ds = self._make_multitask_dataset(
self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids
)
if self.val_ds is None:
self.val_ds = self._make_multitask_dataset(
self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids
)

self.train_ds = self._make_multitask_dataset("train", save_smiles_and_ids=save_smiles_and_ids)
self.val_ds = self._make_multitask_dataset("val", save_smiles_and_ids=save_smiles_and_ids)
logger.info(self.train_ds)
logger.info(self.val_ds)
labels_size.update(
@@ -1216,14 +1227,11 @@ def setup(
labels_dtype.update(self.val_ds.labels_dtype)

if stage == "test" or stage is None:
if self.dataloading_from == "disk":
processed_test_data_path = self._path_to_load_from_file("test")
assert self._data_ready_at_path(
processed_test_data_path
), "Loading from file + setup() called but test data not ready"
else:
processed_test_data_path = None
self.test_ds = self._make_multitask_dataset("test", save_smiles_and_ids=save_smiles_and_ids)
if self.test_ds is None:
self.test_ds = self._make_multitask_dataset(
self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids
)

logger.info(self.test_ds)

labels_size.update(self.test_ds.labels_size)
Expand All @@ -1241,9 +1249,10 @@ def setup(

def _make_multitask_dataset(
self,
dataloading_from: Literal["disk", "ram"],
stage: Literal["train", "val", "test"],
save_smiles_and_ids: bool,
processed_graph_data_path: Optional[str] = None,
# processed_graph_data_path: Optional[str] = None,
) -> Datasets.MultitaskDataset:
"""
Create a MultitaskDataset for the given stage using single task datasets
@@ -1270,18 +1279,7 @@ def _make_multitask_dataset(
else:
raise ValueError(f"Unknown stage {stage}")

if processed_graph_data_path is None:
processed_graph_data_path = self.processed_graph_data_path

# assert singletask_datasets is not None, "Single task datasets must exist to make multitask dataset"
if singletask_datasets is None:
assert processed_graph_data_path is not None
assert self._data_ready_at_path(
self._path_to_load_from_file(stage)
), "Trying to create multitask dataset without single-task datasets but data not ready"
files_ready = True
else:
files_ready = False
processed_graph_data_path = self.processed_graph_data_path

multitask_dataset = Datasets.MultitaskDataset(
singletask_datasets,
Expand All @@ -1292,16 +1290,17 @@ def _make_multitask_dataset(
about=about,
save_smiles_and_ids=save_smiles_and_ids,
data_path=self._path_to_load_from_file(stage) if processed_graph_data_path else None,
processed_graph_data_path=processed_graph_data_path,
files_ready=files_ready,
dataloading_from=dataloading_from,
data_is_cached=self._data_is_cached,
) # type: ignore

# calculate statistics for the train split and used for all splits normalization
if stage == "train":
self.get_label_statistics(
self.processed_graph_data_path, self.data_hash, multitask_dataset, train=True
)
if self.dataloading_from == "ram":
# Normalization has already been applied in cached data
if not self._data_is_prepared:
self.normalize_label(multitask_dataset, stage)

return multitask_dataset
@@ -1342,7 +1341,9 @@ def _save_data_to_files(self) -> None:
# At the moment, we need to merge the `SingleTaskDataset`'s into `MultitaskDataset`s in order to save to file
# This is because the combined labels need to be stored together. We can investigate not doing this if this is a problem
temp_datasets = {
stage: self._make_multitask_dataset(stage, save_smiles_and_ids=False, load_from_file=False)
stage: self._make_multitask_dataset(
dataloading_from="ram", stage=stage, save_smiles_and_ids=False
)
for stage in stages
}
for stage in stages:
@@ -1364,6 +1365,7 @@ def calculate_statistics(self, dataset: Datasets.MultitaskDataset, train: bool =
train: whether the dataset is the training set

"""

if self.task_norms and train:
for task in dataset.labels_size.keys():
# if the label type is graph_*, we need to stack them as the tensor shape is (num_labels, )
graphium/data/dataset.py (57 additions, 12 deletions)
@@ -8,6 +8,8 @@
import os
import numpy as np

from datamol import parallelized, parallelized_with_batches

import torch
from torch.utils.data.dataloader import Dataset
from torch_geometric.data import Data, Batch
@@ -147,7 +149,7 @@ def __init__(
about: str = "",
data_path: Optional[Union[str, os.PathLike]] = None,
dataloading_from: str = "ram",
files_ready: bool = False,
data_is_cached: bool = False,
):
r"""
This class holds the information for the multitask dataset.
@@ -176,7 +178,6 @@ def __init__(
data_is_cached: Whether the files to load from were prepared ahead of time
"""
super().__init__()
# self.datasets = datasets
self.n_jobs = n_jobs
self.backend = backend
self.featurization_batch_size = featurization_batch_size
@@ -185,14 +186,17 @@
self.data_path = data_path
self.dataloading_from = dataloading_from

if files_ready:
if dataloading_from != "disk":
raise ValueError(
"Files are ready to be loaded from disk, but `dataloading_from` is not set to `disk`"
)
logger.info(f"Dataloading from {dataloading_from.upper()}")

if data_is_cached:
self._load_metadata()
self.features = None
self.labels = None

if dataloading_from == "disk":
self.features = None
self.labels = None
elif dataloading_from == "ram":
logger.info("Transferring data from DISK to RAM...")
self.transfer_from_disk_to_ram()

else:
task = next(iter(datasets))
@@ -213,9 +217,50 @@
if self.features is not None:
self._num_nodes_list = get_num_nodes_per_graph(self.features)
self._num_edges_list = get_num_edges_per_graph(self.features)
if self.dataloading_from == "disk":
self.features = None
self.labels = None

def transfer_from_disk_to_ram(self, parallel_with_batches: bool = False):
"""
Function parallelizing transfer from DISK to RAM
"""

def transfer_mol_from_disk_to_ram(idx):
"""
Function transferring single mol from DISK to RAM
"""
data_dict = self.load_graph_from_index(idx)
mol_in_ram = {}
mol_in_ram.update({"features": data_dict["graph_with_features"]})
mol_in_ram.update({"labels": data_dict["labels"]})
if self.smiles is not None:
mol_in_ram.update({"smiles": data_dict["smiles"]})

return mol_in_ram

if parallel_with_batches and self.featurization_batch_size:
data_in_ram = parallelized_with_batches(
transfer_mol_from_disk_to_ram,
range(self.dataset_length),
batch_size=self.featurization_batch_size,
n_jobs=self.n_jobs,
backend=self.backend,
progress=self.progress,
tqdm_kwargs={"desc": "Transfer from DISK to RAM"},
)
else:
data_in_ram = parallelized(
transfer_mol_from_disk_to_ram,
range(self.dataset_length),
n_jobs=self.n_jobs,
backend=self.backend,
progress=self.progress,
tqdm_kwargs={"desc": "Transfer from DISK to RAM"},
)

self.features = [sample["features"] for sample in data_in_ram]
self.labels = [sample["labels"] for sample in data_in_ram]
self.smiles = None
if "smiles" in self.load_graph_from_index(0):
self.smiles = [sample["smiles"] for sample in data_in_ram]

def save_metadata(self, directory: str):
"""
pyproject.toml (1 addition, 0 deletions)
@@ -65,6 +65,7 @@ dependencies = [
[project.scripts]
graphium = "graphium.cli.main:main_cli"
graphium-train = "graphium.cli.train_finetune:cli"
graphium-prepare-data = "graphium.cli.prepare_data:cli"

[project.urls]
Website = "https://graphium.datamol.io/"
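With the new `[project.scripts]` entry, installing the package exposes `graphium-prepare-data` on the PATH alongside the existing commands. A quick sanity check, assuming an editable install from the repository root:

```bash
pip install -e .
graphium-prepare-data --help
```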