Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

U-net implementation #247

Merged
merged 8 commits into from
Sep 27, 2020
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions pl_bolts/models/vision/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    """
    PyTorch Lightning implementation of `U-Net: Convolutional Networks for Biomedical Image Segmentation
    <https://arxiv.org/abs/1505.04597>`_

    Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox

    Model implemented by:
        - `Annika Brundyn <https://github.com/annikabrundyn>`_
        - `Akshay Kulkarni <https://github.com/akshaykvnit>`_

    .. warning:: Work in progress. This implementation is still being verified.

    Args:
        num_classes: Number of output classes required
        num_layers: Number of layers in each side of U-net (default 5)
        features_start: Number of features in first layer (default 64)
        bilinear: Whether to use bilinear interpolation or transposed convolutions (default) for upsampling.
        input_channels: Number of channels in the input images (default 3, i.e. RGB)
    """

    def __init__(
        self,
        num_classes: int,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False,
        input_channels: int = 3,
    ):
        # Guard against a degenerate network: at least one DoubleConv level is required.
        if num_layers < 1:
            raise ValueError(f"num_layers = {num_layers}, expected: num_layers > 0")

        super().__init__()
        self.num_layers = num_layers

        # Contracting path: initial DoubleConv, then (num_layers - 1) Down blocks
        # that each halve the spatial size and double the feature count.
        layers = [DoubleConv(input_channels, features_start)]

        feats = features_start
        for _ in range(num_layers - 1):
            layers.append(Down(feats, feats * 2))
            feats *= 2

        # Expansive path: (num_layers - 1) Up blocks that halve the feature count.
        for _ in range(num_layers - 1):
            layers.append(Up(feats, feats // 2, bilinear))
            feats //= 2

        # Final 1x1 convolution maps the remaining features to class logits.
        layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        xi = [self.layers[0](x)]
        # Down path: store every intermediate output for the skip connections.
        for layer in self.layers[1:self.num_layers]:
            xi.append(layer(xi[-1]))
        # Up path: each Up block consumes the previous output plus the
        # matching skip connection from the contracting path.
        for i, layer in enumerate(self.layers[self.num_layers:-1]):
            xi[-1] = layer(xi[-1], xi[-2 - i])
        return self.layers[-1](xi[-1])


class DoubleConv(nn.Module):
    """
    Two consecutive [Conv2d => BatchNorm => ReLU] stages, as used throughout U-Net.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()

        def stage(cin: int, cout: int):
            # One 3x3 same-padding convolution followed by BN and ReLU.
            return [
                nn.Conv2d(cin, cout, kernel_size=3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
            ]

        # Keep a single Sequential named `net` so state_dict keys stay stable.
        self.net = nn.Sequential(*stage(in_ch, out_ch), *stage(out_ch, out_ch))

    def forward(self, x):
        return self.net(x)


class Down(nn.Module):
    """
    Contracting-path stage: halve the spatial resolution with 2x2 max pooling,
    then apply a DoubleConv block.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Pool first, then convolve; module order preserved for state_dict keys.
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.net = nn.Sequential(pool, DoubleConv(in_ch, out_ch))

    def forward(self, x):
        return self.net(x)


class Up(nn.Module):
    """
    Expansive-path stage: upsample the incoming feature map (bilinear
    interpolation plus a 1x1 conv, or a transposed convolution), pad it to the
    size of the skip connection, concatenate both along the channel axis, and
    run a DoubleConv block.
    """

    def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
        super().__init__()
        if bilinear:
            # Parameter-light upsampling: interpolation plus a 1x1 conv that
            # halves the channel count to mirror the transposed-conv branch.
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(in_ch, in_ch // 2, kernel_size=1),
            )
        else:
            self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        up = self.upsample(x1)

        # Zero-pad the upsampled map so its spatial size matches the skip
        # connection (they can differ when the input size is not a power of 2).
        gap_h = x2.shape[2] - up.shape[2]
        gap_w = x2.shape[3] - up.shape[3]
        up = F.pad(up, [gap_w // 2, gap_w - gap_w // 2, gap_h // 2, gap_h - gap_h // 2])

        # Channel-wise concatenation: skip features first, upsampled second.
        return self.conv(torch.cat([x2, up], dim=1))