From b761e8fd3e577d874ba405d48d34410b76fec8c7 Mon Sep 17 00:00:00 2001 From: James Chapman Date: Wed, 2 Mar 2022 22:51:14 +0000 Subject: [PATCH] Cool everything working! --- multiviewdata/torchdatasets/mnist.py | 23 +++++++++++------------ requirements.txt | 3 ++- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/multiviewdata/torchdatasets/mnist.py b/multiviewdata/torchdatasets/mnist.py index a4e7c37..2116b83 100644 --- a/multiviewdata/torchdatasets/mnist.py +++ b/multiviewdata/torchdatasets/mnist.py @@ -16,7 +16,7 @@ class Split_MNIST_Dataset(Dataset): """ def __init__( - self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True + self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True, download=False, ): """ :param root: Root directory of dataset @@ -25,7 +25,7 @@ def __init__( :param flatten: whether to flatten the data into array or use 2d images """ - self.dataset = load_mnist(mnist_type, train, root) + self.dataset = load_mnist(mnist_type, train, root, download) self.flatten = flatten def __len__(self): @@ -49,7 +49,7 @@ class Noisy_MNIST_Dataset(Dataset): """ def __init__( - self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True + self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True, download=False, ): """ :param root: Root directory of dataset @@ -57,7 +57,7 @@ def __init__( :param train: whether this is train or test :param flatten: whether to flatten the data into array or use 2d images """ - self.dataset = load_mnist(mnist_type, train, root) + self.dataset = load_mnist(mnist_type, train, root, download) self.a_transform = torchvision.transforms.RandomRotation((-45, 45)) self.b_transform = transforms.Compose( [ @@ -99,7 +99,7 @@ class Tangled_MNIST_Dataset(Dataset): """ def __init__( - self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True + self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True, download=False, ): """ :param root: Root directory of dataset @@ -107,7 +107,7 @@ def __init__( :param train: whether this is train or test :param flatten: whether to flatten the data into array or use 2d images """ - self.dataset = load_mnist(mnist_type, train, root) + self.dataset = load_mnist(mnist_type, train, root, download) self.transform = torchvision.transforms.RandomRotation((-45, 45)) self.targets = self.dataset.targets self.filtered_classes = [] @@ -138,13 +138,12 @@ def _add_mnist_noise(x): return x -def load_mnist(mnist_type, train, root): +def load_mnist(mnist_type, train, root, download): if mnist_type == "MNIST": - dataset = load_mnist(mnist_type, train) - datasets.MNIST( + dataset=datasets.MNIST( root, train=train, - download=True, + download=download, transform=torchvision.transforms.Compose( [torchvision.transforms.ToTensor()] ), @@ -154,7 +153,7 @@ def load_mnist(mnist_type, train, root): dataset = datasets.FashionMNIST( root, train=train, - download=True, + download=download, transform=torchvision.transforms.Compose( [torchvision.transforms.ToTensor()] ), @@ -163,7 +162,7 @@ def load_mnist(mnist_type, train, root): dataset = datasets.KMNIST( root, train=train, - download=True, + download=download, transform=torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), diff --git a/requirements.txt b/requirements.txt index d765336..c0dcdbb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ torchvision~=0.11.3 Pillow~=9.0.1 matplotlib~=3.5.1 torch~=1.10.2 -scikit-learn~=1.0.2 \ No newline at end of file +scikit-learn~=1.0.2 +h5py~=3.6.0 \ No newline at end of file