Skip to content

Commit

Permalink
Remove install dep on torchvision
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed May 24, 2019
1 parent 8878a6b commit a22e9ed
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
### Removed
- Removed the fluent decorator, just use return self
- Removed install dependency on `torchvision`, still required for some functionality
### Fixed
- Fixed bug where replay errored when train or val steps were None
- Fixed a bug where mock optimser wouldn't call it's closure
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
torch>=0.4
torchvision<0.3
numpy
scikit-learn
tqdm
Expand All @@ -8,4 +7,5 @@ visdom
livelossplot
mock
Pillow
matplotlib
matplotlib
torchvision<0.3
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
description='A model training and variational auto-encoder library for pytorch',
long_description=long_description,
long_description_content_type='text/markdown',
install_requires=['torch>=0.4', 'torchvision<0.3', 'tqdm'],
install_requires=['torch>=0.4', 'tqdm'],
python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*',
)
4 changes: 2 additions & 2 deletions torchbearer/callbacks/tensor_board.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import torch
import torch.nn.functional as F
import torchvision.utils as utils

import torchbearer
from torchbearer.callbacks import Callback
Expand Down Expand Up @@ -353,7 +352,7 @@ def on_end(self, state):

class TensorBoardImages(AbstractTensorBoard):
"""The TensorBoardImages callback will write a selection of images from the validation pass to tensorboard using the
TensorboardX library and torchvision.utils.make_grid. Images are selected from the given key and saved to the given
TensorboardX library and torchvision.utils.make_grid (requires torchvision). Images are selected from the given key and saved to the given
path. Full name of image sub directory will be model name + _ + comment.
Args:
Expand Down Expand Up @@ -404,6 +403,7 @@ def __init__(self, log_dir='./logs',

def on_step_validation(self, state):
if not self.done:
import torchvision.utils as utils
data = state[self.key].clone()

if len(data.size()) == 3:
Expand Down
9 changes: 7 additions & 2 deletions torchbearer/variational/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.folder import has_file_allowed_extension, default_loader, IMG_EXTENSIONS


def make_dataset(dir, extensions):
from torchvision.datasets.folder import has_file_allowed_extension
images = []
for root, _, fnames in sorted(os.walk(dir)):
for fname in sorted(fnames):
Expand All @@ -20,7 +20,7 @@ def make_dataset(dir, extensions):


class SimpleImageFolder(Dataset):
def __init__(self, root, loader=default_loader, extensions=IMG_EXTENSIONS, transform=None, target_transform=None):
def __init__(self, root, loader=None, extensions=None, transform=None, target_transform=None):
"""
Simple image folder dataset that loads all images from inside a folder and returns items in (image, image) tuple
Expand All @@ -31,6 +31,10 @@ def __init__(self, root, loader=default_loader, extensions=IMG_EXTENSIONS, trans
transform (``Transform``, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (``Transform``, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
"""
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
loader = default_loader if loader is None else loader
extensions = IMG_EXTENSIONS if extensions is None else extensions

samples = make_dataset(root, extensions)

self.root = root
Expand Down Expand Up @@ -91,6 +95,7 @@ def __init__(self, root, as_npy=False, transform=None):
as_npy (bool, optional): If True, assume images are stored in numpy arrays. Else assume a standard image format
transform (``Transform``, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
"""
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
if as_npy:
loader = self.npy_loader
extensions = ['npy']
Expand Down
2 changes: 1 addition & 1 deletion torchbearer/variational/visualisation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
from torchvision.utils import save_image

import torchbearer as tb
import torchbearer.callbacks as c
Expand Down Expand Up @@ -110,6 +109,7 @@ def vis(self, state):
raise NotImplementedError

def _save_walk(self, tensor):
from torchvision.utils import save_image
save_image(tensor, self.file, self.row_size, normalize=True, pad_value=1)


Expand Down

0 comments on commit a22e9ed

Please sign in to comment.