U-net implementation (#247)
* unet implementation

* unet

* clean up

* clean up

* update init

* init

* simple unet test

* simple unet test
annikabrundyn authored Sep 27, 2020
1 parent 32139f4 commit 72e8be3
Showing 5 changed files with 141 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/__init__.py
@@ -1,5 +1,5 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
-from pl_bolts.datamodules.dummy_dataset import DummyDetectionDataset
+from pl_bolts.datamodules.dummy_dataset import DummyDataset, DummyDetectionDataset

try:
    from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
1 change: 1 addition & 0 deletions pl_bolts/models/__init__.py
@@ -8,3 +8,4 @@
from pl_bolts.models.regression import LinearRegression, LogisticRegression
from pl_bolts.models.vision import PixelCNN
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT
+from pl_bolts.models.vision import UNet
1 change: 1 addition & 0 deletions pl_bolts/models/vision/__init__.py
@@ -1 +1,2 @@
from pl_bolts.models.vision.pixel_cnn import PixelCNN
+from pl_bolts.models.vision.unet import UNet
129 changes: 129 additions & 0 deletions pl_bolts/models/vision/unet.py
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    """
    PyTorch implementation of `U-Net: Convolutional Networks for Biomedical Image Segmentation
    <https://arxiv.org/abs/1505.04597>`_

    Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox

    Model implemented by:

        - `Annika Brundyn <https://github.com/annikabrundyn>`_
        - `Akshay Kulkarni <https://github.com/akshaykvnit>`_

    .. warning:: Work in progress. This implementation is still being verified.

    Args:
        num_classes: Number of output classes required
        num_layers: Number of layers in each side of U-Net (default 5)
        features_start: Number of features in the first layer (default 64)
        bilinear: Whether to use bilinear interpolation (``True``) or transposed
            convolutions (``False``, the default) for upsampling
    """

    def __init__(
            self,
            num_classes: int,
            num_layers: int = 5,
            features_start: int = 64,
            bilinear: bool = False
    ):
        super().__init__()
        self.num_layers = num_layers

        # Input block: map the 3-channel image to features_start channels
        layers = [DoubleConv(3, features_start)]

        # Contracting path: halve resolution and double channels at each step
        feats = features_start
        for _ in range(num_layers - 1):
            layers.append(Down(feats, feats * 2))
            feats *= 2

        # Expanding path: double resolution and halve channels at each step
        for _ in range(num_layers - 1):
            layers.append(Up(feats, feats // 2, bilinear))
            feats //= 2

        # Final 1x1 convolution produces per-pixel class scores
        layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        xi = [self.layers[0](x)]
        # Down path: keep each scale's output for the skip connections
        for layer in self.layers[1:self.num_layers]:
            xi.append(layer(xi[-1]))
        # Up path: merge upsampled features with the matching skip connection
        for i, layer in enumerate(self.layers[self.num_layers:-1]):
            xi[-1] = layer(xi[-1], xi[-2 - i])
        # Final 1x1 conv
        return self.layers[-1](xi[-1])


class DoubleConv(nn.Module):
    """
    [ Conv2d => BatchNorm => ReLU ] x 2
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.net(x)


class Down(nn.Module):
    """
    Downscale with MaxPool => DoubleConv block
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        return self.net(x)


class Up(nn.Module):
    """
    Upsampling (by either bilinear interpolation or transposed convolutions)
    followed by concatenation with the feature map from the contracting path,
    followed by DoubleConv.
    """

    def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
        super().__init__()
        self.upsample = None
        if bilinear:
            # Bilinear interpolation has no learnable weights, so a 1x1 conv
            # halves the channel count to match the transposed-conv branch
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(in_ch, in_ch // 2, kernel_size=1),
            )
        else:
            self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.upsample(x1)

        # Pad x1 to the size of x2, so odd input sizes still line up
        diff_h = x2.shape[2] - x1.shape[2]
        diff_w = x2.shape[3] - x1.shape[3]
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

        # Concatenate along the channels axis
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
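
For reference, a minimal usage sketch of the new module (illustrative only, not part of the commit; the input size is arbitrary and assumes the default num_layers=5):

import torch

from pl_bolts.models.vision import UNet

# Default: upsampling via transposed convolutions
model = UNet(num_classes=2)
y = model(torch.rand(1, 3, 256, 256))
print(y.shape)  # torch.Size([1, 2, 256, 256])

# Bilinear variant: nn.Upsample + 1x1 conv in place of transposed convolutions
model = UNet(num_classes=2, bilinear=True)
y = model(torch.rand(1, 3, 256, 256))
print(y.shape)  # torch.Size([1, 2, 256, 256])
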
11 changes: 9 additions & 2 deletions tests/models/test_vision_models.py
@@ -2,8 +2,7 @@
import torch

from pl_bolts.datamodules import MNISTDataModule, FashionMNISTDataModule
-from pl_bolts.models import GPT2, ImageGPT
-
+from pl_bolts.models import GPT2, ImageGPT, UNet

def test_igpt(tmpdir):
    pl.seed_everything(0)
@@ -47,3 +46,11 @@ def test_gpt2(tmpdir):
        num_classes=10,
    )
    model(x)


+def test_unet(tmpdir):
+    x = torch.rand(10, 3, 28, 28)
+    model = UNet(num_classes=2)
+    y = model(x)
+    assert y.shape == torch.Size([10, 2, 28, 28])
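
Because Up.forward pads the upsampled map to the skip connection's spatial size, odd input sizes also round-trip. A quick illustrative check (an assumption about the padding behavior, not part of the committed test):

x = torch.rand(10, 3, 29, 29)
model = UNet(num_classes=2)
assert model(x).shape == torch.Size([10, 2, 29, 29])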
