From 92de2824bc4bfe8f0dee67cecdb65565749a82a5 Mon Sep 17 00:00:00 2001 From: Chufan Gao Date: Tue, 29 Dec 2020 12:32:11 -0500 Subject: [PATCH] removed_trailing_whitespace --- dsm/datasets.py | 12 ++++++------ dsm/dsm_torch.py | 9 ++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/dsm/datasets.py b/dsm/datasets.py index 19a8cd0..e3b4862 100644 --- a/dsm/datasets.py +++ b/dsm/datasets.py @@ -197,10 +197,10 @@ def _load_support_dataset(): def _load_mnist(): """Helper function to load and preprocess the MNIST dataset. - The MNIST database of handwritten digits, available from this page, has a - training set of 60,000 examples, and a test set of 10,000 examples. - It is a good database for people who want to try learning techniques and - pattern recognition methods on real-world data while spending minimal + The MNIST database of handwritten digits, available from this page, has a + training set of 60,000 examples, and a test set of 10,000 examples. + It is a good database for people who want to try learning techniques and + pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting [1]. Please refer to http://yann.lecun.com/exdb/mnist/. @@ -208,13 +208,13 @@ def _load_mnist(): References ---------- - [1]: LeCun, Y. (1998). The MNIST database of handwritten digits. + [1]: LeCun, Y. (1998). The MNIST database of handwritten digits. http://yann.lecun.com/exdb/mnist/. """ - train = torchvision.datasets.MNIST(root='datasets/', + train = torchvision.datasets.MNIST(root='datasets/', train=True, download=True) x = train.data.numpy() x = np.expand_dims(x, 1).astype(float) diff --git a/dsm/dsm_torch.py b/dsm/dsm_torch.py index 44e627f..7f42657 100644 --- a/dsm/dsm_torch.py +++ b/dsm/dsm_torch.py @@ -35,12 +35,11 @@ import torch.nn as nn import torch -import torchvision import numpy as np __pdoc__ = {} -for clsn in ['DeepSurvivalMachinesTorch', +for clsn in ['DeepSurvivalMachinesTorch', 'DeepRecurrentSurvivalMachinesTorch']: for membr in ['training', 'dump_patches']: @@ -368,7 +367,7 @@ def create_conv_representation(inputdim, hidden, typ='ConvNet'): if typ == 'ConvNet': inputdim = np.squeeze(inputdim) - linear_dim = ((((inputdim-2) // 2) - 2) // 2) ** 2 + linear_dim = ((((inputdim-2) // 2) - 2) // 2) ** 2 linear_dim *= 16 embedding = nn.Sequential( nn.Conv2d(1, 6, 3), @@ -474,8 +473,8 @@ def __init__(self, inputdim, k, typ='ConvNet', nn.Linear(hidden, k, bias=True) ) for r in range(self.risks)}) - self.embedding = create_conv_representation(inputdim=inputdim, - hidden=hidden, + self.embedding = create_conv_representation(inputdim=inputdim, + hidden=hidden, typ='ConvNet') def forward(self, x, risk='1'):