From 641021085100433b09a2e627cf69162ce674f41f Mon Sep 17 00:00:00 2001
From: Drew McNutt
Date: Mon, 22 Jul 2024 15:51:48 -0400
Subject: [PATCH] add length and DDP support to MolIterDataset

---
 python/torch_bindings.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/python/torch_bindings.py b/python/torch_bindings.py
index 275c9c4..d99ef1f 100644
--- a/python/torch_bindings.py
+++ b/python/torch_bindings.py
@@ -1,4 +1,7 @@
+import math
 import torch
+from torch.distributed import get_rank, get_world_size, is_initialized
+from torch.utils.data import DistributedSampler
 import molgrid as mg
 import types
 from itertools import islice
@@ -323,13 +326,26 @@ def example_to_tensor(self, ex):
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
-        if worker_info is None:
-            return self.generate()
-        dataset = worker_info.dataset
-        worker_id = worker_info.id
-        n_workers = worker_info.num_workers
+        worker_id = worker_info.id if worker_info is not None else 0
+        n_workers = worker_info.num_workers if worker_info is not None else 1
+
+        world_size = get_world_size() if is_initialized() else 1
+        if world_size == 1:
+            return islice(self.generate(), worker_id, None, n_workers)
+        process_rank = get_rank()
+
+        return islice(self.generate(), process_rank * n_workers + worker_id, None, n_workers * world_size)
+
+    def __len__(self):
+        settings = self.examples.settings()
+        batch_size = settings.default_batch_size
+        if settings.iteration_scheme == mg.IterationScheme.SmallEpoch:
+            return self.examples.small_epoch_size() // batch_size
+        elif settings.iteration_scheme == mg.IterationScheme.LargeEpoch:
+            return math.ceil(self.examples.large_epoch_size() / batch_size)
+        else:
+            raise NotImplementedError('Iteration scheme {} not supported'.format(settings.iteration_scheme))
 
-        return islice(self.generate(), worker_id, None, n_workers)
 
     def __getstate__(self):
         settings = self.examples.settings()
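
Usage sketch (not part of the patch): how the sharded iterator might be
consumed under DDP. The import path, the MolIterDataset constructor call,
the 'train.types' file, and the worker counts below are illustrative
assumptions; the patch only shows that self.examples is a molgrid
ExampleProvider, so the exact __init__ signature may differ.

    import molgrid as mg
    from torch.utils.data import DataLoader
    from molgrid.torch_bindings import MolIterDataset  # import path assumed

    # Assumed construction: arguments forwarded to an ExampleProvider; the
    # settings names match those read back in the new __len__.
    dataset = MolIterDataset('train.types',  # hypothetical types file
                             default_batch_size=32,
                             iteration_scheme=mg.IterationScheme.SmallEpoch)

    # The dataset yields whole batches itself, so disable DataLoader
    # re-batching with batch_size=None. When torch.distributed is initialized
    # (e.g. under torchrun), __iter__ strides the example stream starting at
    # rank * num_workers + worker_id with step num_workers * world_size, so
    # each worker on each rank sees a disjoint slice; in a single process it
    # degrades to plain worker sharding.
    loader = DataLoader(dataset, batch_size=None, num_workers=2)

    print(len(dataset))  # batches per epoch from the new __len__
    for batch in loader:
        pass  # training step goes here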