Skip to content

Commit

Permalink
Merge pull request #113 from jdenholm/dev
Browse files Browse the repository at this point in the history
v0.8.0
  • Loading branch information
jdenholm authored Jan 7, 2024
2 parents 1b7fb46 + 586079e commit dfe4181
Show file tree
Hide file tree
Showing 12 changed files with 1,309 additions and 214 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
## Version 0.8.0
- Add ``get_features`` method to ``SimpleConvNet2d`` for extracting encoded features.
- Add ``get_features`` method to ``ConvNet2d`` for extracting encoded features.
- Added a demo using a multiple-instance-learning attention model.

## Version 0.7.0
- Changed ``AutoEncoder2d`` demo to use ovarian histology images.

## Varsion 0.6.1
## Version 0.6.1
- Added residual blocks as optional block style to all relevant models.
- Changed the ``UNet`` demo to use a nuclei segmentation data set.

Expand Down
861 changes: 861 additions & 0 deletions demos/attention-mil-demo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion demos/autoencoder-demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@
"from numpy import floor\n",
"\n",
"\n",
"total_steps = 2 * 10 ** 4\n",
"total_steps = 2 * 10**4\n",
"\n",
"interval = 500\n",
"\n",
Expand Down
107 changes: 95 additions & 12 deletions demos/conv_net_2d_mnist_demo.ipynb

Large diffs are not rendered by default.

114 changes: 12 additions & 102 deletions demos/shapes_dataset.ipynb

Large diffs are not rendered by default.

204 changes: 168 additions & 36 deletions demos/simple_conv_net_2d_mnist_demo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "torch_tools"
version = "0.7.0"
version = "0.8.0"
description = "PyTorch and other tools"
authors = [
{ name="Jim Denholm", email="j.denholm.2017@gmail.com" },
Expand Down
160 changes: 99 additions & 61 deletions src/torch_tools/datasets/_shapes_dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
"""Synthetic dataset object."""
from typing import Tuple
from typing import Tuple, Optional, Dict, Callable, List

from torch import from_numpy, Tensor # pylint: disable=no-name-in-module
from torch.utils.data import Dataset

from numpy import ones, float32, ndarray, array
from torchvision.transforms import Compose # type: ignore

from numpy import ndarray, array, where, full
from numpy.random import default_rng

from skimage.draw import disk, rectangle # pylint: disable=no-name-in-module
from skimage.morphology import star, square, octagon, disk


class ShapesDataset(Dataset):
"""Synthetic dataset which produces images with spots and squares.
*Warning—*this dataset object is untested.
*Warning*—this dataset object is untested.
Parameters
----------
Expand All @@ -29,33 +31,67 @@ class ShapesDataset(Dataset):
The length of the data set.
image_size : int, optional
The length of the square images.
input_tfms : Compose, optional
A composition of transforms to apply to the input.
target_tfms : Compose, optional
A composition of transforms to apply to the target.
seed : int
Integer seed for numpy's default rng.
Notes
-----
The images have white backgrounds and the shapes have randomly selected
RGB colours on [0, 1)^{3}.
To get the indices of each shape, use, for example
>>> data_set = ShapesDataset()
>>> spot_index = data_set.target_names.index("spot")
>>> star_index = data_set.target_names.index("star")
To print the classes as a list, use
>>> print(data_set.target_names)
"""

def __init__( # pylint: disable=too-many-arguments
self,
spots_prob: float = 0.5,
spot_prob: float = 0.5,
square_prob: float = 0.5,
num_spots: int = 10,
num_squares: int = 10,
star_prob: float = 0.5,
octagon_prob: float = 0.5,
num_shapes: int = 3,
length: int = 1000,
image_size: int = 256,
input_tfms: Optional[Compose] = None,
target_tfms: Optional[Compose] = None,
seed: int = 666,
):
"""Build ``ShapesDataset``."""
self._len = length
self._spot_prob = spots_prob
self._square_prob = square_prob
self._num_spots = num_spots
self._num_squares = num_squares
self._num_shapes = num_shapes

self._probs = {
"square": square_prob,
"spot": spot_prob,
"star": star_prob,
"octagon": octagon_prob,
}

self._img_size = image_size
self._x_tfms = input_tfms
self._y_tfms = target_tfms

self._rng = default_rng(seed=seed)

_rng = default_rng(seed=123)
_shapes: Dict[str, Callable] = {
"square": lambda x: square(2 * x),
"star": star,
"octagon": lambda x: octagon(x, x),
"spot": disk,
}

def __len__(self) -> int:
"""Return the length of the dataset.
Expand All @@ -78,9 +114,9 @@ def __getitem__(self, idx: int):
Returns
-------
Tensor
img : Tensor
An RGB image of shape (c, H, W).
Tensor
tgt : Tensor
Target vector:
— If there are no spots or squares, [0.0, 0.0]
Expand All @@ -89,71 +125,54 @@ def __getitem__(self, idx: int):
— If there are both, [1.0, 1.0]
"""
return self._create_image()
img, tgt = self._create_image()

def _add_spots(self, image: ndarray) -> bool:
if self._x_tfms is not None:
img = self._x_tfms(img)

if self._y_tfms is not None:
tgt = self._y_tfms(tgt)

return img, tgt

def _add_shape(self, image: ndarray, shape: str) -> bool:
"""Add ``shape`` to ``image``.
Parameters
----------
image : ndarray
RGB image.
shape : str
Name of the shape to include.
Returns
-------
include_shape : bool
Whether or not the shape was added.
"""
include_spots = self._rng.random() <= self._spot_prob
include_shape = self._rng.random() <= self._probs[shape]

if include_spots:
for _ in range(self._num_spots):
colour = self._rng.random(size=3)
if include_shape:
for _ in range(self._num_shapes):
colour = self._rng.random(size=(1, 3))
radius = self._img_size // 20
centre = self._rng.integers(
radius,
self._img_size - radius,
size=2,
)

rows, cols = disk(centre, radius)

image[rows, cols, :] = colour
shape_arr = self._shapes[shape](radius)

return include_spots

def _add_squares(self, image: ndarray) -> bool:
"""Add spots to ``image``.
Parameters
----------
image : ndarray
RGB image.
Returns
-------
include_squares : bool
Whether or no squares were included.
"""
include_squares = self._rng.random() <= self._square_prob

if include_squares:
for _ in range(self._num_spots):
colour = self._rng.random(size=3)
length = self._img_size // 20
origin = self._rng.integers(
# pylint: disable=unbalanced-tuple-unpacking
rows, cols = where(shape_arr == 1)
left, top = self._rng.integers(
0,
self._img_size - (2 * length),
self._img_size - len(shape_arr),
size=2,
)

rows, cols = rectangle(origin, origin + (2 * length))
rows, cols = rows + top, cols + left

image[rows, cols, :] = colour
image[rows, cols] = colour

return include_squares
return include_shape

def _create_image(self) -> Tuple[Tensor, Tensor]:
"""Create image.
Expand All @@ -171,12 +190,31 @@ def _create_image(self) -> Tuple[Tensor, Tensor]:
— If there are both, [1.0, 1.0]
"""
image = ones((self._img_size, self._img_size, 3), dtype=float32)
# image = ones((self._img_size, self._img_size, 3), dtype=float32)

image = full(
(self._img_size, self._img_size, 3),
fill_value=self._rng.random(size=(1, 3)),
)

spots = self._add_spots(image)
squares = self._add_squares(image)
targets = []
for key in self._shapes:
targets.append(self._add_shape(image, key))

return (
from_numpy(image).permute(2, 0, 1),
from_numpy(array([spots, squares])).float(),
from_numpy(image).permute(2, 0, 1).float(),
from_numpy(array(targets)).float(),
)

@property
def target_names(self) -> List[str]:
    """List the shape names, ordered by their target-vector indices.

    Returns
    -------
    List[str]
        The keys of ``_shapes`` in their stored order, so a name's
        position in the list is its one-hot index.

    """
    # Unpacking a mapping yields its keys in insertion order.
    return [*self._shapes]
17 changes: 17 additions & 0 deletions src/torch_tools/models/_conv_net_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ def forward(self, batch: Tensor, frozen_encoder: bool = False) -> Tensor:
pool_out = self.pool(encoder_out)
return self.dense_layers(pool_out)

def get_features(self, batch: Tensor) -> Tensor:
    """Encode ``batch`` with the backbone and pool the result.

    Parameters
    ----------
    batch : Tensor
        A mini-batch of image-like inputs.

    Returns
    -------
    Tensor
        The pooled encoder features for the items in ``batch``; the
        dense layers are not applied.

    """
    return self.pool(self.backbone(batch))


def _conv_config(conv: Conv2d) -> Dict[str, Any]:
"""Return a dictionary with the `conv`'s instantiation arguments.
Expand Down
19 changes: 19 additions & 0 deletions src/torch_tools/models/_simple_conv_2d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A simple two-dimensional convolutional neural network."""
from typing import Optional, Dict, Any

from torch import Tensor
from torch.nn import Sequential, Flatten

from torch_tools.models._encoder_2d import Encoder2d
Expand Down Expand Up @@ -107,6 +108,24 @@ def __init__(
),
)

def get_features(self, batch: Tensor) -> Tensor:
    """Get the features produced by the encoder and adaptive pool.

    Parameters
    ----------
    batch : Tensor
        A mini-batch of image-like inputs.

    Returns
    -------
    Tensor
        The features for each item in ``batch``.

    Notes
    -----
    Only the sub-modules at indices 0, 1 and 2 are applied, in that
    order; anything stored after them is skipped. (Presumably these are
    the encoder, the adaptive pool and a flatten layer — confirm against
    ``__init__``.)

    """
    features = batch
    for stage in (0, 1, 2):
        features = self[stage](features)
    return features

_dn_args: Dict[str, Any] = {
"hidden_sizes": None,
"input_bnorm": False,
Expand Down
12 changes: 12 additions & 0 deletions tests/models/test_conv_net_2d_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,15 @@ def test_call_works_with_all_kinds_of_pools():

model = ConvNet2d(out_feats=1, pool_style="avg-max-concat")
_ = model(rand(10, 3, 50, 50))


def test_the_returned_shapes_of_get_features_method():
    """Check the ``get_features`` output shape for several encoders."""
    # Encoder style -> number of features expected after average pooling.
    expected_feats = {"resnet18": 512, "resnet34": 512, "resnet50": 2048}

    for style, num_feats in expected_feats.items():
        model = ConvNet2d(10, encoder_style=style, pool_style="avg")
        features = model.get_features(rand(10, 3, 100, 100))
        assert features.shape == (10, num_feats)
18 changes: 18 additions & 0 deletions tests/models/test_simple_conv_2d_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,21 @@ def test_simple_conv_net_2d_call():
)

assert model(rand(10, in_chans, 100, 100)).shape == (10, out_feats)


def test_the_returned_shapes_of_get_features_method():
    """Test the shapes returned by the ``get_features`` method.

    Every pairing of ``features_start`` and ``num_blocks`` is checked.
    The expected feature count is ``features_start`` doubled once per
    block after the first.
    """
    # Nested loops rather than ``zip``: zipping a two-item list with a
    # three-item list stops at the shorter one, so ``num_blocks=5`` (and
    # most pairings) would never be exercised.
    for feats_start in [16, 32]:
        for num_blocks in [3, 4, 5]:
            model = SimpleConvNet2d(
                3,
                10,
                features_start=feats_start,
                num_blocks=num_blocks,
            )

            batch = rand(10, 3, 100, 100)

            assert model.get_features(batch).shape == (
                10,
                feats_start * (2 ** (num_blocks - 1)),
            )

0 comments on commit dfe4181

Please sign in to comment.