diff --git a/test/e2e/mnist.py b/test/e2e/mnist.py index 4cfe0b43..5ac26665 100644 --- a/test/e2e/mnist.py +++ b/test/e2e/mnist.py @@ -20,7 +20,7 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader, random_split, RandomSampler from torchmetrics import Accuracy from torchvision import transforms from torchvision.datasets import MNIST @@ -158,7 +158,7 @@ def setup(self, stage=None): ) def train_dataloader(self): - return DataLoader(self.mnist_train, batch_size=BATCH_SIZE) + return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000)) def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=BATCH_SIZE) @@ -178,10 +178,11 @@ def test_dataloader(self): trainer = Trainer( accelerator="auto", # devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs - max_epochs=5, + max_epochs=3, callbacks=[TQDMProgressBar(refresh_rate=20)], num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)), devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)), + replace_sampler_ddp=False, strategy="ddp", ) diff --git a/test/odh/resources/mnist.py b/test/odh/resources/mnist.py index e88e8fc9..85d420f4 100644 --- a/test/odh/resources/mnist.py +++ b/test/odh/resources/mnist.py @@ -20,7 +20,7 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader, random_split, RandomSampler from torchmetrics import Accuracy from torchvision import transforms from torchvision.datasets import MNIST @@ -158,7 +158,7 @@ def setup(self, stage=None): ) def train_dataloader(self): - return DataLoader(self.mnist_train, batch_size=BATCH_SIZE) + return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000)) def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=BATCH_SIZE) @@ -178,10 +178,11 @@ def test_dataloader(self): trainer = Trainer( accelerator="auto", # devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs - max_epochs=2, + max_epochs=3, callbacks=[TQDMProgressBar(refresh_rate=20)], num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)), devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)), + replace_sampler_ddp=False, strategy="ddp", )