Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v0.8.0 #113

Merged
merged 45 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
54fd31a
:art: Applied black
actions-user Nov 2, 2023
4973755
:memo: Added attention demo
actions-user Nov 2, 2023
0b423a8
:sparkles: Added more shapes
actions-user Nov 5, 2023
f3adcf1
:alembic: Changed shape size
actions-user Nov 8, 2023
8a5a119
:alembic: Experimenting with attention
actions-user Nov 8, 2023
fb908ab
:memo: Working on demo
actions-user Dec 22, 2023
bd014b5
:sparkles: Added random seed argument
actions-user Dec 22, 2023
65f5fda
:memo: Updated demo
actions-user Dec 22, 2023
2eeb238
:memo: Updated demo
actions-user Dec 22, 2023
7525882
:memo: Updated demo
actions-user Dec 22, 2023
4796cfb
:memo: Updated demo
actions-user Dec 22, 2023
e3abad2
:recycle: Increased size of squares
actions-user Dec 23, 2023
d181b8d
:memo: Updated demo
actions-user Dec 23, 2023
b79d762
:memo: Updated docstring
actions-user Dec 23, 2023
f1aae3d
:memo: Updated demo
actions-user Dec 23, 2023
be5e075
:memo: Updated demo
actions-user Dec 23, 2023
adc6892
:memo: Updated docs
actions-user Jan 2, 2024
30cce83
:memo: Updated demo
actions-user Jan 2, 2024
d8021fc
:memo: Added rotations to demo
actions-user Jan 3, 2024
22a19e8
:wrench: Updated to next release version
actions-user Jan 3, 2024
c5036f7
:memo: Updated changelog
actions-user Jan 3, 2024
628cb39
:memo: Added rotations to demo
actions-user Jan 3, 2024
3e2ea6b
:memo: Added rotations to demo
actions-user Jan 3, 2024
cc938cf
:memo: Ren demo for longer
actions-user Jan 3, 2024
48cba2a
:sparkles: Added function for extracting encoded features to simple c…
actions-user Jan 3, 2024
57a1fae
:sparkles: Added get features method to ConvNet2d
actions-user Jan 3, 2024
bb1629d
:white_check_mark: Added test for calling get_features method
actions-user Jan 3, 2024
2bb6142
:memo: Updated demo
actions-user Jan 3, 2024
76e7995
:memo: Updated model and text
actions-user Jan 3, 2024
2b4f200
:memo: Updated demo with get_feats
actions-user Jan 6, 2024
faeb3e8
:memo: Updated demo with get_feats
actions-user Jan 6, 2024
d684804
:art: Applied black
actions-user Jan 6, 2024
4dacbcf
:white_check_mark: Added tests for get_features method call
actions-user Jan 6, 2024
3aa5115
:memo: Updateed Changelog
actions-user Jan 6, 2024
5f23b97
:memo: Updateed Changelog
actions-user Jan 6, 2024
cfc621a
:memo: Updateed Changelog
actions-user Jan 6, 2024
e6ae95d
Merge pull request #111 from jdenholm/conv-feats
jdenholm Jan 6, 2024
19c19f5
:memo: Updated changelog
actions-user Jan 6, 2024
0c09d3e
:rotating_light: Removed mypy warning
actions-user Jan 6, 2024
b8a464f
:rotating_light: Removed mypy warning
actions-user Jan 6, 2024
5534ba6
Merge pull request #112 from jdenholm/mil-attention
jdenholm Jan 6, 2024
621ebf6
:pencil2: Typo fix
actions-user Jan 7, 2024
0eb719a
:memo: Updated demo
actions-user Jan 7, 2024
2f01339
:sparkles: :memo: Added list of class names and docs on getting their…
actions-user Jan 7, 2024
586079e
:memo: Docstring tweak
actions-user Jan 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
## Version 0.8.0
- Add ``get_features`` method to ``SimpleConvNet2d`` for extracting encoded features.
- Add ``get_features`` method to ``ConvNet2d`` for extracting encoded features.
- Added a demo using a multiple-instance-learning attention model.

## Version 0.7.0
- Changed ``AutoEncoder2d`` demo to use ovarian histology images.

## Varsion 0.6.1
## Version 0.6.1
- Added residual blocks as optional block style to all relevant models.
- Changed the ``UNet`` demo to use a nuclei segmentation data set.

Expand Down
861 changes: 861 additions & 0 deletions demos/attention-mil-demo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion demos/autoencoder-demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@
"from numpy import floor\n",
"\n",
"\n",
"total_steps = 2 * 10 ** 4\n",
"total_steps = 2 * 10**4\n",
"\n",
"interval = 500\n",
"\n",
Expand Down
107 changes: 95 additions & 12 deletions demos/conv_net_2d_mnist_demo.ipynb

Large diffs are not rendered by default.

114 changes: 12 additions & 102 deletions demos/shapes_dataset.ipynb

Large diffs are not rendered by default.

204 changes: 168 additions & 36 deletions demos/simple_conv_net_2d_mnist_demo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "torch_tools"
version = "0.7.0"
version = "0.8.0"
description = "PyTorch and other tools"
authors = [
{ name="Jim Denholm", email="j.denholm.2017@gmail.com" },
Expand Down
160 changes: 99 additions & 61 deletions src/torch_tools/datasets/_shapes_dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
"""Synthetic dataset object."""
from typing import Tuple
from typing import Tuple, Optional, Dict, Callable, List

from torch import from_numpy, Tensor # pylint: disable=no-name-in-module
from torch.utils.data import Dataset

from numpy import ones, float32, ndarray, array
from torchvision.transforms import Compose # type: ignore

from numpy import ndarray, array, where, full
from numpy.random import default_rng

from skimage.draw import disk, rectangle # pylint: disable=no-name-in-module
from skimage.morphology import star, square, octagon, disk


class ShapesDataset(Dataset):
"""Synthetic dataset which produces images withs spots and squares.

*Warning—*this dataset object is untested.
*Warning*—this dataset object is untested.

Parameters
----------
Expand All @@ -29,33 +31,67 @@ class ShapesDataset(Dataset):
The length of the data set.
image_size : int, optional
The length of the square images.
input_tfms : Compose, optional
A composition of transforms to apply to the input.
target_tfms : Compose, optional
A composition of transforms to apply to the target.
seed : int
Integer seed for numpy's default rng.

Notes
-----
The images have white backrounds and the shapes have randomly selected
RGB colours on [0, 1)^{3}.

To get the indices of each shape, use, for example

>>> data_set = ShapesDataset()
>>> spot_index = data_set.target_names.index("spot")
>>> star_index = data_set.target_names.index("star")

To print the classes as a list, use

>>> print(data_set.target_names)


"""

def __init__( # pylint: disable=too-many-arguments
self,
spots_prob: float = 0.5,
spot_prob: float = 0.5,
square_prob: float = 0.5,
num_spots: int = 10,
num_squares: int = 10,
star_prob: float = 0.5,
octagon_prob: float = 0.5,
num_shapes: int = 3,
length: int = 1000,
image_size: int = 256,
input_tfms: Optional[Compose] = None,
target_tfms: Optional[Compose] = None,
seed: int = 666,
):
"""Build ``ShapesDataset``."""
self._len = length
self._spot_prob = spots_prob
self._square_prob = square_prob
self._num_spots = num_spots
self._num_squares = num_squares
self._num_shapes = num_shapes

self._probs = {
"square": square_prob,
"spot": spot_prob,
"star": star_prob,
"octagon": octagon_prob,
}

self._img_size = image_size
self._x_tfms = input_tfms
self._y_tfms = target_tfms

self._rng = default_rng(seed=seed)

_rng = default_rng(seed=123)
_shapes: Dict[str, Callable] = {
"square": lambda x: square(2 * x),
"star": star,
"octagon": lambda x: octagon(x, x),
"spot": disk,
}

def __len__(self) -> int:
"""Return the length of the dataset.
Expand All @@ -78,9 +114,9 @@ def __getitem__(self, idx: int):

Returns
-------
Tensor
img : Tensor
An RGB image of shape (c, H, W).
Tensor
tgt : Tensor
Target vector:

— If there are no spots or squares, [0.0, 0.0]
Expand All @@ -89,71 +125,54 @@ def __getitem__(self, idx: int):
— If there are both, [1.0, 1.0]

"""
return self._create_image()
img, tgt = self._create_image()

def _add_spots(self, image: ndarray) -> bool:
if self._x_tfms is not None:
img = self._x_tfms(img)

if self._y_tfms is not None:
tgt = self._y_tfms(tgt)

return img, tgt

def _add_shape(self, image: ndarray, shape: str) -> bool:
"""Add spots to ``image``.

Parameters
----------
image : ndarray
RGB image.
shape : str
Name of the shape to include.


Returns
-------
include_spots : bool
Whether or not the spots were added.

"""
include_spots = self._rng.random() <= self._spot_prob
include_shape = self._rng.random() <= self._probs[shape]

if include_spots:
for _ in range(self._num_spots):
colour = self._rng.random(size=3)
if include_shape:
for _ in range(self._num_shapes):
colour = self._rng.random(size=(1, 3))
radius = self._img_size // 20
centre = self._rng.integers(
radius,
self._img_size - radius,
size=2,
)

rows, cols = disk(centre, radius)

image[rows, cols, :] = colour
shape_arr = self._shapes[shape](radius)

return include_spots

def _add_squares(self, image: ndarray) -> bool:
"""Add spots to ``image``.

Parameters
----------
image : ndarray
RGB image.

Returns
-------
include_squares : bool
Whether or no squares were included.

"""
include_squares = self._rng.random() <= self._square_prob

if include_squares:
for _ in range(self._num_spots):
colour = self._rng.random(size=3)
length = self._img_size // 20
origin = self._rng.integers(
# pylint: disable=unbalanced-tuple-unpacking
rows, cols = where(shape_arr == 1)
left, top = self._rng.integers(
0,
self._img_size - (2 * length),
self._img_size - len(shape_arr),
size=2,
)

rows, cols = rectangle(origin, origin + (2 * length))
rows, cols = rows + top, cols + left

image[rows, cols, :] = colour
image[rows, cols] = colour

return include_squares
return include_shape

def _create_image(self) -> Tuple[Tensor, Tensor]:
"""Create image.
Expand All @@ -171,12 +190,31 @@ def _create_image(self) -> Tuple[Tensor, Tensor]:
— If there are both, [1.0, 1.0]

"""
image = ones((self._img_size, self._img_size, 3), dtype=float32)
# image = ones((self._img_size, self._img_size, 3), dtype=float32)

image = full(
(self._img_size, self._img_size, 3),
fill_value=self._rng.random(size=(1, 3)),
)

spots = self._add_spots(image)
squares = self._add_squares(image)
targets = []
for key in self._shapes:
targets.append(self._add_shape(image, key))

return (
from_numpy(image).permute(2, 0, 1),
from_numpy(array([spots, squares])).float(),
from_numpy(image).permute(2, 0, 1).float(),
from_numpy(array(targets)).float(),
)

@property
def target_names(self) -> List[str]:
"""Return a list of target names order by their one-hot indices.

Returns
-------
List[str]
A list of the names of the shapes, ordered by their one-hot
indices.

"""
return list(self._shapes.keys())
17 changes: 17 additions & 0 deletions src/torch_tools/models/_conv_net_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ def forward(self, batch: Tensor, frozen_encoder: bool = False) -> Tensor:
pool_out = self.pool(encoder_out)
return self.dense_layers(pool_out)

def get_features(self, batch: Tensor) -> Tensor:
"""Return the features produced by the encoder and pool.

Parameters
----------
batch : Tensor
A mini-batch of image-like inputs.

Returns
-------
Tensor
The encoded features for the items in ``batch``.

"""
encoder_out = self.backbone(batch)
return self.pool(encoder_out)


def _conv_config(conv: Conv2d) -> Dict[str, Any]:
"""Return a dictionary with the `conv`'s instantiation arguments.
Expand Down
19 changes: 19 additions & 0 deletions src/torch_tools/models/_simple_conv_2d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A simple two-dimensional convolutional neural network."""
from typing import Optional, Dict, Any

from torch import Tensor
from torch.nn import Sequential, Flatten

from torch_tools.models._encoder_2d import Encoder2d
Expand Down Expand Up @@ -107,6 +108,24 @@ def __init__(
),
)

def get_features(self, batch: Tensor) -> Tensor:
"""Get the features produced by the encoder and adaptive poool.

Parameters
----------
batch : Tensor
A mini-batch of image-like inputs.

Returns
-------
Tensor
The features for each item in ``batch``.

"""
feats = self[0](batch)
feats = self[1](feats)
return self[2](feats)

_dn_args: Dict[str, Any] = {
"hidden_sizes": None,
"input_bnorm": False,
Expand Down
12 changes: 12 additions & 0 deletions tests/models/test_conv_net_2d_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,15 @@ def test_call_works_with_all_kinds_of_pools():

model = ConvNet2d(out_feats=1, pool_style="avg-max-concat")
_ = model(rand(10, 3, 50, 50))


def test_the_returned_shapes_of_get_features_method():
"""Test the shapes returned by the ``get_features`` method."""
model = ConvNet2d(10, encoder_style="resnet18", pool_style="avg")
assert model.get_features(rand(10, 3, 100, 100)).shape == (10, 512)

model = ConvNet2d(10, encoder_style="resnet34", pool_style="avg")
assert model.get_features(rand(10, 3, 100, 100)).shape == (10, 512)

model = ConvNet2d(10, encoder_style="resnet50", pool_style="avg")
assert model.get_features(rand(10, 3, 100, 100)).shape == (10, 2048)
18 changes: 18 additions & 0 deletions tests/models/test_simple_conv_2d_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,21 @@ def test_simple_conv_net_2d_call():
)

assert model(rand(10, in_chans, 100, 100)).shape == (10, out_feats)


def test_the_returned_shapes_of_get_features_method():
"""Test the shapes returned by the ``get_features`` method."""
for feats_start, num_blocks in zip([16, 32], [3, 4, 5]):
model = SimpleConvNet2d(
3,
10,
features_start=feats_start,
num_blocks=num_blocks,
)

batch = rand(10, 3, 100, 100)

assert model.get_features(batch).shape == (
10,
feats_start * (2 ** (num_blocks - 1)),
)
Loading