Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revision models.vision.unet, models.vision.segmentation #880

Merged
merged 11 commits into from
Sep 19, 2022
69 changes: 47 additions & 22 deletions pl_bolts/models/vision/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,71 @@
from argparse import ArgumentParser
from typing import Any, Dict, Optional

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch import Tensor
from torch.nn import functional as F

from pl_bolts.models.vision.unet import UNet
from pl_bolts.utils.stability import under_review


@under_review()
class SemSegment(LightningModule):
"""Basic model for semantic segmentation. Uses UNet architecture by default.

The default parameters in this model are for the KITTI dataset. Note, if you'd like to use this model as is,
you will first need to download the KITTI dataset yourself. You can download the dataset `here.
<http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015>`_

Implemented by:

- `Annika Brundyn <https://github.com/annikabrundyn>`_

Example::

from pl_bolts.models.vision import SemSegment

model = SemSegment(num_classes=19)
dm = KittiDataModule(data_dir='/path/to/kitti/')

Trainer().fit(model, datamodule=dm)

Example CLI::

# KITTI
python segmentation.py --data_dir /path/to/kitti/ --accelerator=gpu
"""

def __init__(
self,
lr: float = 0.01,
num_classes: int = 19,
num_layers: int = 5,
features_start: int = 64,
bilinear: bool = False,
ignore_index: Optional[int] = 250,
lr: float = 0.01,
**kwargs: Any
):
"""Basic model for semantic segmentation. Uses UNet architecture by default.

The default parameters in this model are for the KITTI dataset. Note, if you'd like to use this model as is,
you will first need to download the KITTI dataset yourself. You can download the dataset `here.
<http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015>`_

Implemented by:

- `Annika Brundyn <https://github.com/annikabrundyn>`_

"""
Args:
num_classes: number of output classes (default 19)
num_layers: number of layers in each side of U-net (default 5)
features_start: number of features in first layer (default 64)
bilinear: whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
lr: learning (default 0.01)
ignore_index: target value to be ignored in cross_entropy (default 250)
lr: learning rate (default 0.01)
"""

super().__init__()

self.num_classes = num_classes
self.num_layers = num_layers
self.features_start = features_start
self.bilinear = bilinear
if ignore_index is None:
# set ignore_index to default value of F.cross_entropy if it is None.
self.ignore_index = -100
else:
self.ignore_index = ignore_index
self.lr = lr

self.net = UNet(
Expand All @@ -49,24 +75,24 @@ def __init__(
bilinear=self.bilinear,
)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return self.net(x)

def training_step(self, batch, batch_nb):
def training_step(self, batch: Tensor, batch_idx: int) -> Dict[str, Any]:
img, mask = batch
img = img.float()
mask = mask.long()
out = self(img)
loss_val = F.cross_entropy(out, mask, ignore_index=250)
loss_val = F.cross_entropy(out, mask, ignore_index=self.ignore_index)
log_dict = {"train_loss": loss_val}
return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict}

def validation_step(self, batch, batch_idx):
def validation_step(self, batch: Tensor, batch_idx: int) -> Dict[str, Any]:
img, mask = batch
img = img.float()
mask = mask.long()
out = self(img)
loss_val = F.cross_entropy(out, mask, ignore_index=250)
loss_val = F.cross_entropy(out, mask, ignore_index=self.ignore_index)
return {"val_loss": loss_val}

def validation_epoch_end(self, outputs):
Expand All @@ -80,7 +106,7 @@ def configure_optimizers(self):
return [opt], [sch]

@staticmethod
def add_model_specific_args(parent_parser):
def add_model_specific_args(parent_parser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--lr", type=float, default=0.01, help="adam: learning rate")
parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
Expand All @@ -92,7 +118,6 @@ def add_model_specific_args(parent_parser):
return parser


@under_review()
def cli_main():
from pl_bolts.datamodules import KittiDataModule

Expand All @@ -115,7 +140,7 @@ def cli_main():
model = SemSegment(**args.__dict__)

# train
trainer = Trainer().from_argparse_args(args)
trainer = Trainer.from_argparse_args(args)
trainer.fit(model, datamodule=dm)


Expand Down
27 changes: 11 additions & 16 deletions pl_bolts/models/vision/unet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import torch
from torch import nn
from torch import Tensor, nn
from torch.nn import functional as F

from pl_bolts.utils.stability import under_review


@under_review()
class UNet(nn.Module):
"""
"""Pytorch Lightning implementation of U-Net.

Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation
<https://arxiv.org/abs/1505.04597>`_

Expand All @@ -23,7 +21,7 @@ class UNet(nn.Module):
input_channels: Number of channels in input images (default 3)
num_layers: Number of layers in each side of U-net (default 5)
features_start: Number of features in first layer (default 64)
bilinear: Whether to use bilinear interpolation or transposed convolutions (default) for upsampling.
bilinear: Whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
"""

def __init__(
Expand Down Expand Up @@ -56,7 +54,7 @@ def __init__(

self.layers = nn.ModuleList(layers)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
xi = [self.layers[0](x)]
# Down path
for layer in self.layers[1 : self.num_layers]:
Expand All @@ -67,38 +65,35 @@ def forward(self, x):
return self.layers[-1](xi[-1])


@under_review()
class DoubleConv(nn.Module):
"""[ Conv2d => BatchNorm (optional) => ReLU ] x 2."""
"""[ Conv2d => BatchNorm => ReLU ] x 2."""

def __init__(self, in_ch: int, out_ch: int):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
otaj marked this conversation as resolved.
Show resolved Hide resolved
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return self.net(x)


@under_review()
class Down(nn.Module):
"""Downscale with MaxPool => DoubleConvolution block."""

def __init__(self, in_ch: int, out_ch: int):
super().__init__()
self.net = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), DoubleConv(in_ch, out_ch))

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return self.net(x)


@under_review()
class Up(nn.Module):
"""Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature
map from contracting path, followed by DoubleConv."""
Expand All @@ -116,7 +111,7 @@ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):

self.conv = DoubleConv(in_ch, out_ch)

def forward(self, x1, x2):
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
x1 = self.upsample(x1)

# Pad x1 to the size of x2
Expand Down
55 changes: 49 additions & 6 deletions tests/models/test_vision.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
import warnings

import pytest
import torch
from packaging import version
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning import __version__ as pl_version
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader

from pl_bolts.datamodules import FashionMNISTDataModule, MNISTDataModule
from pl_bolts.datasets import DummyDataset
from pl_bolts.models.vision import GPT2, ImageGPT, SemSegment, UNet
from pl_bolts.models.vision.unet import DoubleConv, Down, Up


class DummyDataModule(LightningDataModule):
def train_dataloader(self):
train_ds = DummyDataset((3, 35, 120), (35, 120), num_samples=100)
return DataLoader(train_ds, batch_size=1)

def val_dataloader(self):
valid_ds = DummyDataset((3, 35, 120), (35, 120), num_samples=100)
return DataLoader(valid_ds, batch_size=1)


@pytest.mark.skipif(
version.parse(pl_version) > version.parse("1.1.0"), reason="igpt code not updated for latest lightning"
Expand Down Expand Up @@ -68,22 +76,57 @@ def test_gpt2():
model(x)


def test_unet_component(catch_warnings):
x1 = torch.rand(1, 3, 28, 28)
x2 = torch.rand(1, 64, 28, 33)
x3 = torch.rand(1, 32, 64, 69)

doubleConvLayer = DoubleConv(3, 64)
y = doubleConvLayer(x1)
assert y.shape == torch.Size([1, 64, 28, 28])

downLayer = Down(3, 6)
y = downLayer(x1)
assert y.shape == torch.Size([1, 6, 14, 14])

upLayer1 = Up(64, 32, False)
upLayer2 = Up(64, 32, True)
y1 = upLayer1(x2, x3)
y2 = upLayer2(x2, x3)
assert y1.shape == torch.Size([1, 32, 64, 69])
assert y2.shape == torch.Size([1, 32, 64, 69])


@torch.no_grad()
def test_unet():
def test_unet(catch_warnings):
x = torch.rand(10, 3, 28, 28)
model = UNet(num_classes=2)
y = model(x)
assert y.shape == torch.Size([10, 2, 28, 28])


def test_semantic_segmentation(tmpdir):
def test_semantic_segmentation(tmpdir, catch_warnings):
warnings.filterwarnings(
"ignore",
message="The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck",
category=PossibleUserWarning,
)
warnings.filterwarnings(
"ignore",
message="The dataloader, train_dataloader, does not have many workers which may be a bottleneck",
category=PossibleUserWarning,
)
dm = DummyDataModule()

model = SemSegment(num_classes=19)
progress_bar = TQDMProgressBar()

trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, callbacks=[progress_bar])
trainer = Trainer(
fast_dev_run=True,
max_epochs=-1,
default_root_dir=tmpdir,
logger=False,
accelerator="auto",
callbacks=[progress_bar],
)
trainer.fit(model, datamodule=dm)
loss = progress_bar.get_metrics(trainer, model)["loss"]

assert float(loss) > 0