From a22e9ed7341eba0524a7ebd0968665942de66f29 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 24 May 2019 08:08:10 +0100 Subject: [PATCH] Remove install dep on torchvision --- CHANGELOG.md | 1 + requirements.txt | 4 ++-- setup.py | 2 +- torchbearer/callbacks/tensor_board.py | 4 ++-- torchbearer/variational/datasets.py | 9 +++++++-- torchbearer/variational/visualisation.py | 2 +- 6 files changed, 14 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b82b9adc..cd5aae78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated ### Removed - Removed the fluent decorator, just use return self +- Removed install dependency on `torchvision`, still required for some functionality ### Fixed - Fixed bug where replay errored when train or val steps were None - Fixed a bug where mock optimser wouldn't call it's closure diff --git a/requirements.txt b/requirements.txt index df69e53b..553bdad5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ torch>=0.4 -torchvision<0.3 numpy scikit-learn tqdm @@ -8,4 +7,5 @@ visdom livelossplot mock Pillow -matplotlib \ No newline at end of file +matplotlib +torchvision<0.3 \ No newline at end of file diff --git a/setup.py b/setup.py index 34e065ea..f0cb1a83 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,6 @@ description='A model training and variational auto-encoder library for pytorch', long_description=long_description, long_description_content_type='text/markdown', - install_requires=['torch>=0.4', 'torchvision<0.3', 'tqdm'], + install_requires=['torch>=0.4', 'tqdm'], python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', ) diff --git a/torchbearer/callbacks/tensor_board.py b/torchbearer/callbacks/tensor_board.py index 3c3cf335..4f087126 100644 --- a/torchbearer/callbacks/tensor_board.py +++ b/torchbearer/callbacks/tensor_board.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F -import torchvision.utils as utils import torchbearer from torchbearer.callbacks import Callback @@ -353,7 +352,7 @@ def on_end(self, state): class TensorBoardImages(AbstractTensorBoard): """The TensorBoardImages callback will write a selection of images from the validation pass to tensorboard using the - TensorboardX library and torchvision.utils.make_grid. Images are selected from the given key and saved to the given + TensorboardX library and torchvision.utils.make_grid (requires torchvision). Images are selected from the given key and saved to the given path. Full name of image sub directory will be model name + _ + comment. Args: @@ -404,6 +403,7 @@ def __init__(self, log_dir='./logs', def on_step_validation(self, state): if not self.done: + import torchvision.utils as utils data = state[self.key].clone() if len(data.size()) == 3: diff --git a/torchbearer/variational/datasets.py b/torchbearer/variational/datasets.py index 05bf9790..a8bce021 100644 --- a/torchbearer/variational/datasets.py +++ b/torchbearer/variational/datasets.py @@ -5,10 +5,10 @@ import numpy as np from PIL import Image from torch.utils.data import Dataset -from torchvision.datasets.folder import has_file_allowed_extension, default_loader, IMG_EXTENSIONS def make_dataset(dir, extensions): + from torchvision.datasets.folder import has_file_allowed_extension images = [] for root, _, fnames in sorted(os.walk(dir)): for fname in sorted(fnames): @@ -20,7 +20,7 @@ def make_dataset(dir, extensions): class SimpleImageFolder(Dataset): - def __init__(self, root, loader=default_loader, extensions=IMG_EXTENSIONS, transform=None, target_transform=None): + def __init__(self, root, loader=None, extensions=None, transform=None, target_transform=None): """ Simple image folder dataset that loads all images from inside a folder and returns items in (image, image) tuple @@ -31,6 +31,10 @@ def __init__(self, root, loader=default_loader, extensions=IMG_EXTENSIONS, trans transform (``Transform``, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (``Transform``, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` """ + from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS + loader = default_loader if loader is None else loader + extensions = IMG_EXTENSIONS if extensions is None else extensions + samples = make_dataset(root, extensions) self.root = root @@ -91,6 +95,7 @@ def __init__(self, root, as_npy=False, transform=None): as_npy (bool, optional): If True, assume images are stored in numpy arrays. Else assume a standard image format transform (``Transform``, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` """ + from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS if as_npy: loader = self.npy_loader extensions = ['npy'] diff --git a/torchbearer/variational/visualisation.py b/torchbearer/variational/visualisation.py index ca047993..5a8c0d41 100644 --- a/torchbearer/variational/visualisation.py +++ b/torchbearer/variational/visualisation.py @@ -1,5 +1,4 @@ import torch -from torchvision.utils import save_image import torchbearer as tb import torchbearer.callbacks as c @@ -110,6 +109,7 @@ def vis(self, state): raise NotImplementedError def _save_walk(self, tensor): + from torchvision.utils import save_image save_image(tensor, self.file, self.row_size, normalize=True, pad_value=1)