Skip to content

Commit

Permalink
add length and DDP support to MolIterDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
drewnutt committed Jul 22, 2024
1 parent 4674cfa commit 6410210
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions python/torch_bindings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import math
import torch
from torch.distributed import get_rank, get_world_size
from torch.utils.data import DistributedSampler
import molgrid as mg
import types
from itertools import islice
Expand Down Expand Up @@ -323,13 +326,26 @@ def example_to_tensor(self, ex):

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
return self.generate()
dataset = worker_info.dataset
worker_id = worker_info.id
n_workers = worker_info.num_workers
worker_id = worker_info.id if worker_info is not None else 0
n_workers = worker_info.num_workers if worker_info is not None else 1

world_size = get_world_size()
if world_size == 1:
return islice(self.generate(), worker_id, None, n_workers)
process_rank = get_rank()

return islice(self.generate(), process_rank * n_workers + worker_id, None, n_workers * world_size)

def __len__(self):
settings = self.examples.settings()
batch_size = settings.default_batch_size
if settings.iteration_scheme == mg.IterationScheme.SmallEpoch:
return self.examples.small_epoch_size() // batch_size
elif settings.iteration_scheme == mg.IterationScheme.LargeEpoch:
return math.ceil(self.examples.large_epoch_size() / batch_size)
else:
NotImplementedError('Iteration scheme %s not supported'.format(itr_scheme))

return islice(self.generate(), worker_id, None, n_workers)

def __getstate__(self):
settings = self.examples.settings()
Expand Down

0 comments on commit 6410210

Please sign in to comment.