Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

40-validation-early-stop #73

Merged
merged 5 commits into from
Feb 13, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 64 additions & 62 deletions examples/unet_segmentation_3d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,13 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MONAI version: 0.0.1\n",
"Python version: 3.7.3 (default, Mar 27 2019, 22:11:17) [GCC 7.3.0]\n",
"Numpy version: 1.16.4\n",
"Pytorch version: 1.3.1\n",
"Ignite version: 0.2.1\n"
]
"text": "MONAI version: 0.0.1\nPython version: 3.8.1 (default, Jan 8 2020, 22:29:32) [GCC 7.3.0]\nNumpy version: 1.18.1\nPytorch version: 1.4.0\nIgnite version: 0.3.0\n"
}
],
"source": [
Expand All @@ -28,29 +22,32 @@
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"import monai.data.transforms.compose as transforms\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import nibabel as nib\n",
"\n",
"from ignite.engine import Events, create_supervised_trainer\n",
"from ignite.handlers import ModelCheckpoint\n",
"from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
"from ignite.handlers import ModelCheckpoint, EarlyStopping\n",
"\n",
"# assumes the framework is found here, change as necessary\n",
"sys.path.append(\"..\")\n",
"\n",
"\n",
"import monai.data.transforms.compose as transforms\n",
"from monai import application, data, networks, utils\n",
"from monai.data.readers import NiftiDataset\n",
"from monai.data.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch, GridPatchDataset\n",
"from monai.networks.metrics.mean_dice import MeanDice\n",
"from monai.utils.stopperutils import stopping_fn_from_metric\n",
"\n",
"\n",
"application.config.print_config()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -81,7 +78,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -99,15 +96,13 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n"
]
"text": "torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n"
}
],
"source": [
Expand Down Expand Up @@ -136,7 +131,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -157,46 +152,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 Loss: 0.8619852662086487\n",
"Epoch 2 Loss: 0.8307779431343079\n",
"Epoch 3 Loss: 0.8064168691635132\n",
"Epoch 4 Loss: 0.7981672883033752\n",
"Epoch 5 Loss: 0.7950631976127625\n",
"Epoch 6 Loss: 0.7949732542037964\n",
"Epoch 7 Loss: 0.7963427901268005\n",
"Epoch 8 Loss: 0.7939450144767761\n",
"Epoch 9 Loss: 0.7926643490791321\n",
"Epoch 10 Loss: 0.7911991477012634\n",
"Epoch 11 Loss: 0.7886414527893066\n",
"Epoch 12 Loss: 0.7867528796195984\n",
"Epoch 13 Loss: 0.7857398390769958\n",
"Epoch 14 Loss: 0.7833380699157715\n",
"Epoch 15 Loss: 0.7791398763656616\n",
"Epoch 16 Loss: 0.7720394730567932\n",
"Epoch 17 Loss: 0.7671006917953491\n",
"Epoch 18 Loss: 0.7646064758300781\n",
"Epoch 19 Loss: 0.7672612071037292\n",
"Epoch 20 Loss: 0.7600041627883911\n",
"Epoch 21 Loss: 0.7583478689193726\n",
"Epoch 22 Loss: 0.7571365833282471\n",
"Epoch 23 Loss: 0.7545363306999207\n",
"Epoch 24 Loss: 0.7499511241912842\n",
"Epoch 25 Loss: 0.7481640577316284\n",
"Epoch 26 Loss: 0.7469437122344971\n",
"Epoch 27 Loss: 0.7460543513298035\n",
"Epoch 28 Loss: 0.74577796459198\n",
"Epoch 29 Loss: 0.7429620027542114\n",
"Epoch 30 Loss: 0.7424858808517456\n"
]
}
],
"outputs": [],
"source": [
"trainEpochs = 30\n",
"\n",
Expand All @@ -218,16 +176,60 @@
"\n",
"\n",
"loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n",
" \n",
"val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"validation_every_n_epochs = 1\n",
"\n",
"val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}\n",
"evaluator = create_supervised_evaluator(net, val_metrics, device, True,\n",
" output_transform=lambda x, y, y_pred: (y_pred[0], y))\n",
"\n",
"\n",
"early_stopper = EarlyStopping(patience=4, \n",
" score_function=stopping_fn_from_metric('Mean Dice'),\n",
" trainer=trainer)\n",
"evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)\n",
"\n",
"@evaluator.on(Events.EPOCH_COMPLETED)\n",
"def log_validation_metrics(engine):\n",
" for name, value in engine.state.metrics.items():\n",
" print(\"Validation --\", name, \":\", value)\n",
"\n",
"@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))\n",
"def run_validation(engine):\n",
" evaluator.run(val_loader)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Epoch 1 Loss: 0.8975554704666138\nValidation -- Mean Dice : 0.11846490800380707\nEpoch 2 Loss: 0.8451039791107178\nValidation -- Mean Dice : 0.12091563045978546\nEpoch 3 Loss: 0.9355515241622925\nValidation -- Mean Dice : 0.12139833569526673\nEpoch 4 Loss: 0.843208909034729\nValidation -- Mean Dice : 0.12108306288719177\nEpoch 5 Loss: 0.8225834965705872\nValidation -- Mean Dice : 0.12179622799158096\nEpoch 6 Loss: 0.957372784614563\nValidation -- Mean Dice : 0.12193384170532226\nEpoch 7 Loss: 0.9011092782020569\nValidation -- Mean Dice : 0.1230143740773201\nEpoch 8 Loss: 0.8651387691497803\nValidation -- Mean Dice : 0.1254110112786293\nEpoch 9 Loss: 0.8767974972724915\nValidation -- Mean Dice : 0.12633273899555206\nEpoch 10 Loss: 0.8193061947822571\nValidation -- Mean Dice : 0.12657881826162337\nEpoch 11 Loss: 0.9466649293899536\nValidation -- Mean Dice : 0.12699378579854964\nEpoch 12 Loss: 0.8258659243583679\nValidation -- Mean Dice : 0.12790720015764237\nEpoch 13 Loss: 0.8661612868309021\nValidation -- Mean Dice : 0.12980296313762665\nEpoch 14 Loss: 0.8039132356643677\nValidation -- Mean Dice : 0.1311295285820961\nEpoch 15 Loss: 0.8050084114074707\nValidation -- Mean Dice : 0.13225494623184203\nEpoch 16 Loss: 0.9048625230789185\nValidation -- Mean Dice : 0.1330576255917549\nEpoch 17 Loss: 0.9179995656013489\nValidation -- Mean Dice : 0.13361359685659407\nEpoch 18 Loss: 0.8956605195999146\nValidation -- Mean Dice : 0.13432369381189346\nEpoch 19 Loss: 0.8029189705848694\nValidation -- Mean Dice : 0.13532216250896453\nEpoch 20 Loss: 0.8359838128089905\nValidation -- Mean Dice : 0.13622953295707702\nEpoch 21 Loss: 0.9225850105285645\nValidation -- Mean Dice : 0.13677610754966735\nEpoch 22 Loss: 0.7023072242736816\nValidation -- Mean Dice : 0.13693425357341765\nEpoch 23 Loss: 0.8776397705078125\nValidation -- Mean Dice : 0.13710424304008484\nEpoch 24 Loss: 0.9571539163589478\nValidation -- Mean Dice : 0.1370883911848068\nEpoch 25 Loss: 0.8877002596855164\nValidation -- Mean Dice : 0.13701471388339997\nEpoch 26 Loss: 0.817417562007904\nValidation -- Mean Dice : 0.13696834743022918\nEpoch 27 Loss: 0.8971314430236816\nValidation -- Mean Dice : 0.1371448516845703\nEpoch 28 Loss: 0.9443905353546143\nValidation -- Mean Dice : 0.13739778995513915\nEpoch 29 Loss: 0.7578094005584717\nValidation -- Mean Dice : 0.137495020031929\nEpoch 30 Loss: 0.7037953734397888\nValidation -- Mean Dice : 0.13759489357471466\n"
}
],
"source": [
"state = trainer.run(loader, trainEpochs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.7.5 64-bit ('pytorch': conda)",
"language": "python",
"name": "python3"
"name": "python37564bitpytorchconda9e7dd2186ac2430b947ee08d8eff35b4"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -239,9 +241,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.8.1-final"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}
41 changes: 33 additions & 8 deletions examples/unet_segmentation_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
import torch
import monai.data.transforms.compose as transforms
from torch.utils.tensorboard import SummaryWriter
from ignite.engine import Events, create_supervised_trainer
from ignite.handlers import ModelCheckpoint
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader

from monai import application, networks, utils
from monai.data.readers import NiftiDataset
from monai.data.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
from monai.application.handlers.stats_handler import StatsHandler
from monai.networks.metrics.mean_dice import MeanDice
from monai.utils.stopperutils import stopping_fn_from_metric

# assumes the framework is found here, change as necessary
sys.path.append("..")
Expand Down Expand Up @@ -60,7 +61,7 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma

return noisyimage, labels


# Create a temporary directory and 50 random image, mask paris
tempdir = tempfile.mkdtemp()

for i in range(50):
Expand All @@ -75,18 +76,21 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma
images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

# Define transforms for image and segmentation
imtrans = transforms.Compose([Rescale(), AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()])

segtrans = transforms.Compose([AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()])

# Define nifti dataset, dataloader.
ds = NiftiDataset(images, segs, imtrans, segtrans)

loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = utils.mathutils.first(loader)
print(im.shape, seg.shape)


lr = 1e-3
train_epochs = 30

# Create UNet, DiceLoss and Adam optimizer.
net = networks.nets.UNet(
dimensions=3,
in_channels=1,
Expand All @@ -101,13 +105,12 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma

train_epochs = 3


# Since network outputs logits and segmentation, we need a custom function.
def _loss_fn(i, j):
return loss(i[0], j)


# Create trainer
device = torch.device("cuda:0")

trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])

Expand Down Expand Up @@ -154,6 +157,28 @@ def log_training_loss(engine):


loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())
writer = SummaryWriter()

# Define mean dice metric and Evaluator.
validation_every_n_epochs = 1

val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
output_transform=lambda x, y, y_pred: (y_pred[0], y))

val_stats_handler = StatsHandler()
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
early_stopper = EarlyStopping(patience=4,
score_function=stopping_fn_from_metric('Mean Dice'),
trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
evaluator.run(val_loader)


state = trainer.run(loader, train_epochs)
22 changes: 22 additions & 0 deletions monai/utils/stopperutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

def stopping_fn_from_metric(metric_name):
Nic-Ma marked this conversation as resolved.
Show resolved Hide resolved
"""Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name."""
def stopping_fn(engine):
return engine.state.metrics[metric_name]
return stopping_fn

def stopping_fn_from_loss():
"""Returns a stopping function for ignite.handlers.EarlyStopping using the loss value."""
def stopping_fn(engine):
return -engine.state.output
return stopping_fn