🚀 Feature

Consider adding FastDataLoader to the DGL library.

Motivation

I got a 2x speedup using this drop-in replacement, extracted from here:
import inspect

import torch
from dgl.dataloading import GraphCollator  # used by GraphDataLoader below; in DGL's own source it lives in the same module


class _RepeatSampler(object):
    """Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """
    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class FastDataLoader(torch.utils.data.dataloader.DataLoader):
    """DataLoader that creates its worker processes once and reuses them across epochs."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Swap in a never-ending batch sampler so the single worker iterator
        # created below is never exhausted. object.__setattr__ is needed because
        # DataLoader forbids reassigning batch_sampler after initialization.
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)
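For illustration, FastDataLoader takes the same arguments as torch.utils.data.DataLoader, so it can be swapped in directly. A minimal usage sketch (the toy tensors, batch size, worker count, and epoch count are placeholders, not part of the proposal):

import torch
from torch.utils.data import TensorDataset

data = TensorDataset(torch.randn(10000, 16), torch.randint(0, 2, (10000,)))
loader = FastDataLoader(data, batch_size=256, shuffle=True, num_workers=4)

for epoch in range(3):
    for x, y in loader:  # the 4 workers are spawned once and reused every epoch
        pass

Because __iter__ draws from a single long-lived iterator created in __init__, the worker processes are started once at construction time instead of at the start of every epoch. The GraphDataLoader below builds on this class for datasets of DGL graphs.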
class GraphDataLoader:
    """PyTorch dataloader for batch-iterating over a set of graphs, generating the
    batched graph and the corresponding label tensor (if provided) for each minibatch.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The dataset from which to load the graphs.
    collate_fn : Function, default is None
        The customized collate function. Will use the default collate
        function if not given.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
    """
    collator_arglist = inspect.getfullargspec(GraphCollator).args

    def __init__(self, dataset, collate_fn=None, **kwargs):
        # Split the keyword arguments between the collator and the underlying DataLoader.
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if collate_fn is None:
            self.collate = GraphCollator(**collator_kwargs).collate
        else:
            self.collate = collate_fn

        self.dataloader = FastDataLoader(dataset=dataset,
                                         collate_fn=self.collate,
                                         **dataloader_kwargs)

    def __iter__(self):
        """Return the iterator of the data loader."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)
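The wall times reported below illustrate the speedup. A rough sketch of that kind of comparison is shown here, assuming DGL's bundled GINDataset and GraphCollator; the dataset choice, batch size, worker count, and epoch count are arbitrary placeholders, and this is not the exact benchmark that produced the numbers:

import time

import torch
import dgl
from dgl.data import GINDataset

dataset = GINDataset('MUTAG', self_loop=False)
collate = dgl.dataloading.GraphCollator().collate  # the default collator GraphDataLoader would use

def run(loader, n_epochs=5):
    start = time.perf_counter()
    for _ in range(n_epochs):
        for batched_graph, labels in loader:
            pass
    return time.perf_counter() - start

slow = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
                                   num_workers=4, collate_fn=collate)
fast = FastDataLoader(dataset, batch_size=32, shuffle=True,
                      num_workers=4, collate_fn=collate)
print('DataLoader:     %.1f s' % run(slow))
print('FastDataLoader: %.1f s' % run(fast))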
#Wall time: 13.3 s  (stock torch.utils.data.DataLoader)
#Wall time: 6.72 s  (FastDataLoader)

Cons

Apparently it may break in some situations: Lightning-AI/pytorch-lightning#1506

It also has some limitations: pytorch/pytorch#15849 (comment)


Looks like the PyTorch issue suggests that the overhead comes from re-initializing the worker processes at the start of every epoch. Did you try the persistent_workers option in newer PyTorch versions?
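For reference, a minimal sketch of that built-in alternative, assuming PyTorch 1.7 or later (where persistent_workers was added); the dataset, batch size, and worker count are placeholders:

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.randn(10000, 16), torch.randint(0, 2, (10000,)))
loader = DataLoader(data, batch_size=256, shuffle=True,
                    num_workers=4, persistent_workers=True)

for epoch in range(3):
    for x, y in loader:  # worker processes are kept alive between epochs
        pass

With persistent_workers=True (it requires num_workers > 0), the stock DataLoader keeps its workers alive between epochs, which targets the same re-initialization overhead without a custom subclass.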