-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* unet implementation * unet * clean up * clean up * update init * init * simple unet test * simple unet test
- Loading branch information
1 parent
32139f4
commit 72e8be3
Showing
5 changed files
with
141 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from pl_bolts.models.vision.pixel_cnn import PixelCNN | ||
from pl_bolts.models.vision.unet import UNet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class UNet(nn.Module): | ||
""" | ||
PyTorch Lightning implementation of `U-Net: Convolutional Networks for Biomedical Image Segmentation | ||
<https://arxiv.org/abs/1505.04597>`_ | ||
Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox | ||
Model implemented by: | ||
- `Annika Brundyn <https://github.com/annikabrundyn>`_ | ||
- `Akshay Kulkarni <https://github.com/akshaykvnit>`_ | ||
.. warning:: Work in progress. This implementation is still being verified. | ||
Args: | ||
num_classes: Number of output classes required | ||
num_layers: Number of layers in each side of U-net (default 5) | ||
features_start: Number of features in first layer (default 64) | ||
bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
num_classes: int, | ||
num_layers: int = 5, | ||
features_start: int = 64, | ||
bilinear: bool = False | ||
): | ||
super().__init__() | ||
self.num_layers = num_layers | ||
|
||
layers = [DoubleConv(3, features_start)] | ||
|
||
feats = features_start | ||
for _ in range(num_layers - 1): | ||
layers.append(Down(feats, feats * 2)) | ||
feats *= 2 | ||
|
||
for _ in range(num_layers - 1): | ||
layers.append(Up(feats, feats // 2, bilinear)) | ||
feats //= 2 | ||
|
||
layers.append(nn.Conv2d(feats, num_classes, kernel_size=1)) | ||
|
||
self.layers = nn.ModuleList(layers) | ||
|
||
def forward(self, x): | ||
xi = [self.layers[0](x)] | ||
# Down path | ||
for layer in self.layers[1:self.num_layers]: | ||
xi.append(layer(xi[-1])) | ||
# Up path | ||
for i, layer in enumerate(self.layers[self.num_layers:-1]): | ||
xi[-1] = layer(xi[-1], xi[-2 - i]) | ||
return self.layers[-1](xi[-1]) | ||
|
||
|
||
class DoubleConv(nn.Module): | ||
""" | ||
[ Conv2d => BatchNorm (optional) => ReLU ] x 2 | ||
""" | ||
|
||
def __init__(self, in_ch: int, out_ch: int): | ||
super().__init__() | ||
self.net = nn.Sequential( | ||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), | ||
nn.BatchNorm2d(out_ch), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), | ||
nn.BatchNorm2d(out_ch), | ||
nn.ReLU(inplace=True) | ||
) | ||
|
||
def forward(self, x): | ||
return self.net(x) | ||
|
||
|
||
class Down(nn.Module): | ||
""" | ||
Downscale with MaxPool => DoubleConvolution block | ||
""" | ||
|
||
def __init__(self, in_ch: int, out_ch: int): | ||
super().__init__() | ||
self.net = nn.Sequential( | ||
nn.MaxPool2d(kernel_size=2, stride=2), | ||
DoubleConv(in_ch, out_ch) | ||
) | ||
|
||
def forward(self, x): | ||
return self.net(x) | ||
|
||
|
||
class Up(nn.Module): | ||
""" | ||
Upsampling (by either bilinear interpolation or transpose convolutions) | ||
followed by concatenation of feature map from contracting path, | ||
followed by DoubleConv. | ||
""" | ||
|
||
def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): | ||
super().__init__() | ||
self.upsample = None | ||
if bilinear: | ||
self.upsample = nn.Sequential( | ||
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), | ||
nn.Conv2d(in_ch, in_ch // 2, kernel_size=1), | ||
) | ||
else: | ||
self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2) | ||
|
||
self.conv = DoubleConv(in_ch, out_ch) | ||
|
||
def forward(self, x1, x2): | ||
x1 = self.upsample(x1) | ||
|
||
# Pad x1 to the size of x2 | ||
diff_h = x2.shape[2] - x1.shape[2] | ||
diff_w = x2.shape[3] - x1.shape[3] | ||
|
||
x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]) | ||
|
||
# Concatenate along the channels axis | ||
x = torch.cat([x2, x1], dim=1) | ||
return self.conv(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters