Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
Cool everything working!
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Mar 2, 2022
1 parent fbd7ce7 commit b761e8f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
23 changes: 11 additions & 12 deletions multiviewdata/torchdatasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Split_MNIST_Dataset(Dataset):
"""

def __init__(
self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True, download=False,
):
"""
:param root: Root directory of dataset
Expand All @@ -25,7 +25,7 @@ def __init__(
:param flatten: whether to flatten the data into array or use 2d images
"""

self.dataset = load_mnist(mnist_type, train, root)
self.dataset = load_mnist(mnist_type, train, root, download)
self.flatten = flatten

def __len__(self):
Expand All @@ -49,15 +49,15 @@ class Noisy_MNIST_Dataset(Dataset):
"""

def __init__(
self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True, download=False,
):
"""
:param root: Root directory of dataset
:param mnist_type: "MNIST", "FashionMNIST" or "KMNIST"
:param train: whether this is train or test
:param flatten: whether to flatten the data into array or use 2d images
"""
self.dataset = load_mnist(mnist_type, train, root)
self.dataset = load_mnist(mnist_type, train, root, download)
self.a_transform = torchvision.transforms.RandomRotation((-45, 45))
self.b_transform = transforms.Compose(
[
Expand Down Expand Up @@ -99,15 +99,15 @@ class Tangled_MNIST_Dataset(Dataset):
"""

def __init__(
self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
self, root: str, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True, download=False,
):
"""
:param root: Root directory of dataset
:param mnist_type: "MNIST", "FashionMNIST" or "KMNIST"
:param train: whether this is train or test
:param flatten: whether to flatten the data into array or use 2d images
"""
self.dataset = load_mnist(mnist_type, train, root)
self.dataset = load_mnist(mnist_type, train, root, download)
self.transform = torchvision.transforms.RandomRotation((-45, 45))
self.targets = self.dataset.targets
self.filtered_classes = []
Expand Down Expand Up @@ -138,13 +138,12 @@ def _add_mnist_noise(x):
return x


def load_mnist(mnist_type, train, root):
def load_mnist(mnist_type, train, root, download):
if mnist_type == "MNIST":
dataset = load_mnist(mnist_type, train)
datasets.MNIST(
dataset=datasets.MNIST(
root,
train=train,
download=True,
download=download,
transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
),
Expand All @@ -154,7 +153,7 @@ def load_mnist(mnist_type, train, root):
dataset = datasets.FashionMNIST(
root,
train=train,
download=True,
download=download,
transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
),
Expand All @@ -163,7 +162,7 @@ def load_mnist(mnist_type, train, root):
dataset = datasets.KMNIST(
root,
train=train,
download=True,
download=download,
transform=torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ torchvision~=0.11.3
Pillow~=9.0.1
matplotlib~=3.5.1
torch~=1.10.2
scikit-learn~=1.0.2
scikit-learn~=1.0.2
h5py~=3.6.0

0 comments on commit b761e8f

Please sign in to comment.