From 70b2ef1cc4560c88d7379780321bd5c576f834d6 Mon Sep 17 00:00:00 2001 From: Mohammad Adil Date: Tue, 11 Feb 2020 16:19:59 -0800 Subject: [PATCH 1/4] Adding MVP example for Validation every epoch with Early Stopping. --- examples/unet_segmentation_3d.ipynb | 121 ++++++++++++++-------------- monai/utils/stopperutils.py | 22 +++++ 2 files changed, 82 insertions(+), 61 deletions(-) create mode 100644 monai/utils/stopperutils.py diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index a5e614199a..f35c1c20c3 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -2,19 +2,13 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", - "text": [ - "MONAI version: 0.0.1\n", - "Python version: 3.7.3 (default, Mar 27 2019, 22:11:17) [GCC 7.3.0]\n", - "Numpy version: 1.16.4\n", - "Pytorch version: 1.3.1\n", - "Ignite version: 0.2.1\n" - ] + "text": "MONAI version: 0.0.1\nPython version: 3.8.1 (default, Jan 8 2020, 22:29:32) [GCC 7.3.0]\nNumpy version: 1.18.1\nPytorch version: 1.4.0\nIgnite version: 0.3.0\n" } ], "source": [ @@ -34,8 +28,8 @@ "import matplotlib.pyplot as plt\n", "import nibabel as nib\n", "\n", - "from ignite.engine import Events, create_supervised_trainer\n", - "from ignite.handlers import ModelCheckpoint\n", + "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n", + "from ignite.handlers import ModelCheckpoint, EarlyStopping\n", "\n", "# assumes the framework is found here, change as necessary\n", "sys.path.append(\"..\")\n", @@ -43,6 +37,8 @@ "from monai import application, data, networks, utils\n", "from monai.data.readers import NiftiDataset\n", "from monai.data.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch, GridPatchDataset\n", + "from monai.networks.metrics.mean_dice import MeanDice\n", + "from monai.utils.stopperutils import stopping_fn_from_metric\n", "\n", "\n", "application.config.print_config()" @@ -50,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -81,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -99,15 +95,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", - "text": [ - "torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n" - ] + "text": "torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n" } ], "source": [ @@ -136,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -157,46 +151,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 18, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 Loss: 0.8619852662086487\n", - "Epoch 2 Loss: 0.8307779431343079\n", - "Epoch 3 Loss: 0.8064168691635132\n", - "Epoch 4 Loss: 0.7981672883033752\n", - "Epoch 5 Loss: 0.7950631976127625\n", - "Epoch 6 Loss: 0.7949732542037964\n", - "Epoch 7 Loss: 0.7963427901268005\n", - "Epoch 8 Loss: 0.7939450144767761\n", - "Epoch 9 Loss: 0.7926643490791321\n", - "Epoch 10 Loss: 0.7911991477012634\n", - "Epoch 11 Loss: 0.7886414527893066\n", - "Epoch 12 Loss: 0.7867528796195984\n", - "Epoch 13 Loss: 0.7857398390769958\n", - "Epoch 14 Loss: 0.7833380699157715\n", - "Epoch 15 Loss: 0.7791398763656616\n", - "Epoch 16 Loss: 0.7720394730567932\n", - "Epoch 17 Loss: 0.7671006917953491\n", - "Epoch 18 Loss: 0.7646064758300781\n", - "Epoch 19 Loss: 0.7672612071037292\n", - "Epoch 20 Loss: 0.7600041627883911\n", - "Epoch 21 Loss: 0.7583478689193726\n", - "Epoch 22 Loss: 0.7571365833282471\n", - "Epoch 23 Loss: 0.7545363306999207\n", - "Epoch 24 Loss: 0.7499511241912842\n", - "Epoch 25 Loss: 0.7481640577316284\n", - "Epoch 26 Loss: 0.7469437122344971\n", - "Epoch 27 Loss: 0.7460543513298035\n", - "Epoch 28 Loss: 0.74577796459198\n", - "Epoch 29 Loss: 0.7429620027542114\n", - "Epoch 30 Loss: 0.7424858808517456\n" - ] - } - ], + "outputs": [], "source": [ "trainEpochs = 30\n", "\n", @@ -218,16 +175,58 @@ "\n", "\n", "loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n", - " \n", + "val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "metrics = {'Mean Dice': MeanDice()}\n", + "evaluator = create_supervised_evaluator(net, metrics, device, True,\n", + " output_transform=lambda x, y, y_pred: (y_pred[1], y))\n", + "\n", + "\n", + "early_stopper = EarlyStopping(patience=4, \n", + " score_function=stopping_fn_from_metric('Mean Dice'),\n", + " trainer=trainer)\n", + "evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)\n", + "\n", + "@evaluator.on(Events.EPOCH_COMPLETED)\n", + "def print_metrics(engine):\n", + " for name, value in engine.state.metrics.items():\n", + " print(f\"{name}: {value}\")\n", + "\n", + "@trainer.on(Events.EPOCH_COMPLETED)\n", + "def run_validation(engine):\n", + " evaluator.run(val_loader)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": "Epoch 1 Loss: 0.9148738384246826\nMean Dice: 0.15293407514691354\nEpoch 2 Loss: 0.8974388837814331\nMean Dice: 0.15451385974884033\nEpoch 3 Loss: 0.9722865223884583\nMean Dice: 0.15694797784090042\nEpoch 4 Loss: 0.8608269095420837\nMean Dice: 0.15774961411952973\nEpoch 5 Loss: 0.8680331110954285\nMean Dice: 0.1586161732673645\nEpoch 6 Loss: 0.862883448600769\nMean Dice: 0.16078170761466026\nEpoch 7 Loss: 0.8934938311576843\nMean Dice: 0.16227573975920678\nEpoch 8 Loss: 0.8121705651283264\nMean Dice: 0.1634128823876381\nEpoch 9 Loss: 0.9347629547119141\nMean Dice: 0.16452214494347572\nEpoch 10 Loss: 0.834267795085907\nMean Dice: 0.16538378223776817\nEpoch 11 Loss: 0.8011281490325928\nMean Dice: 0.16629234701395035\nEpoch 12 Loss: 0.9585697650909424\nMean Dice: 0.16719597205519676\nEpoch 13 Loss: 0.8275477290153503\nMean Dice: 0.16779896840453148\nEpoch 14 Loss: 0.8383943438529968\nMean Dice: 0.16819630116224288\nEpoch 15 Loss: 0.9151533246040344\nMean Dice: 0.1686399482190609\nEpoch 16 Loss: 0.8661578297615051\nMean Dice: 0.169024233520031\nEpoch 17 Loss: 0.8742436170578003\nMean Dice: 0.1694291800260544\nEpoch 18 Loss: 0.9113590717315674\nMean Dice: 0.16978714540600776\nEpoch 19 Loss: 0.7151992321014404\nMean Dice: 0.17010726779699326\nEpoch 20 Loss: 0.8364595770835876\nMean Dice: 0.1703476406633854\nEpoch 21 Loss: 0.8860700130462646\nMean Dice: 0.17060176506638527\nEpoch 22 Loss: 0.9314338564872742\nMean Dice: 0.17088036686182023\nEpoch 23 Loss: 0.8412138223648071\nMean Dice: 0.17105951756238938\nEpoch 24 Loss: 0.8748369216918945\nMean Dice: 0.17121266573667526\nEpoch 25 Loss: 0.6639310121536255\nMean Dice: 0.17138929441571235\nEpoch 26 Loss: 0.7708219289779663\nMean Dice: 0.17157517224550248\nEpoch 27 Loss: 0.8601589798927307\nMean Dice: 0.17178838700056076\nEpoch 28 Loss: 0.8095047473907471\nMean Dice: 0.17197039127349853\nEpoch 29 Loss: 0.962582528591156\nMean Dice: 0.17208218201994896\nEpoch 30 Loss: 0.8462375402450562\nMean Dice: 0.17212391123175622\n" + } + ], + "source": [ "state = trainer.run(loader, trainEpochs)" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.7.5 64-bit ('pytorch': conda)", "language": "python", - "name": "python3" + "name": "python37564bitpytorchconda9e7dd2186ac2430b947ee08d8eff35b4" }, "language_info": { "codemirror_mode": { @@ -239,9 +238,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.8.1-final" } }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/monai/utils/stopperutils.py b/monai/utils/stopperutils.py new file mode 100644 index 0000000000..e40c533c7f --- /dev/null +++ b/monai/utils/stopperutils.py @@ -0,0 +1,22 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def stopping_fn_from_metric(metric_name): + """Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.""" + def stopping_fn(engine): + return engine.state.metrics[metric_name] + return stopping_fn + +def stopping_fn_from_loss(): + """Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.""" + def stopping_fn(engine): + return -engine.state.loss.item() + return stopping_fn From db692009b9ddb9643eb9fa92acd182367f22ffe6 Mon Sep 17 00:00:00 2001 From: Mohammad Adil Date: Wed, 12 Feb 2020 17:23:31 -0800 Subject: [PATCH 2/4] Add every_n_epochs and update python example. --- examples/unet_segmentation_3d.ipynb | 22 ++++++++------- examples/unet_segmentation_3d.py | 43 +++++++++++++++++++++++------ monai/utils/stopperutils.py | 2 +- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index f35c1c20c3..dc48b368c1 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 15, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -77,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -95,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -130,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -151,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -181,10 +181,12 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ + "validation_every_N_epochs = 1\n", + "\n", "metrics = {'Mean Dice': MeanDice()}\n", "evaluator = create_supervised_evaluator(net, metrics, device, True,\n", " output_transform=lambda x, y, y_pred: (y_pred[1], y))\n", @@ -200,7 +202,7 @@ " for name, value in engine.state.metrics.items():\n", " print(f\"{name}: {value}\")\n", "\n", - "@trainer.on(Events.EPOCH_COMPLETED)\n", + "@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_N_epochs))\n", "def run_validation(engine):\n", " evaluator.run(val_loader)\n", "\n" @@ -208,13 +210,13 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", - "text": "Epoch 1 Loss: 0.9148738384246826\nMean Dice: 0.15293407514691354\nEpoch 2 Loss: 0.8974388837814331\nMean Dice: 0.15451385974884033\nEpoch 3 Loss: 0.9722865223884583\nMean Dice: 0.15694797784090042\nEpoch 4 Loss: 0.8608269095420837\nMean Dice: 0.15774961411952973\nEpoch 5 Loss: 0.8680331110954285\nMean Dice: 0.1586161732673645\nEpoch 6 Loss: 0.862883448600769\nMean Dice: 0.16078170761466026\nEpoch 7 Loss: 0.8934938311576843\nMean Dice: 0.16227573975920678\nEpoch 8 Loss: 0.8121705651283264\nMean Dice: 0.1634128823876381\nEpoch 9 Loss: 0.9347629547119141\nMean Dice: 0.16452214494347572\nEpoch 10 Loss: 0.834267795085907\nMean Dice: 0.16538378223776817\nEpoch 11 Loss: 0.8011281490325928\nMean Dice: 0.16629234701395035\nEpoch 12 Loss: 0.9585697650909424\nMean Dice: 0.16719597205519676\nEpoch 13 Loss: 0.8275477290153503\nMean Dice: 0.16779896840453148\nEpoch 14 Loss: 0.8383943438529968\nMean Dice: 0.16819630116224288\nEpoch 15 Loss: 0.9151533246040344\nMean Dice: 0.1686399482190609\nEpoch 16 Loss: 0.8661578297615051\nMean Dice: 0.169024233520031\nEpoch 17 Loss: 0.8742436170578003\nMean Dice: 0.1694291800260544\nEpoch 18 Loss: 0.9113590717315674\nMean Dice: 0.16978714540600776\nEpoch 19 Loss: 0.7151992321014404\nMean Dice: 0.17010726779699326\nEpoch 20 Loss: 0.8364595770835876\nMean Dice: 0.1703476406633854\nEpoch 21 Loss: 0.8860700130462646\nMean Dice: 0.17060176506638527\nEpoch 22 Loss: 0.9314338564872742\nMean Dice: 0.17088036686182023\nEpoch 23 Loss: 0.8412138223648071\nMean Dice: 0.17105951756238938\nEpoch 24 Loss: 0.8748369216918945\nMean Dice: 0.17121266573667526\nEpoch 25 Loss: 0.6639310121536255\nMean Dice: 0.17138929441571235\nEpoch 26 Loss: 0.7708219289779663\nMean Dice: 0.17157517224550248\nEpoch 27 Loss: 0.8601589798927307\nMean Dice: 0.17178838700056076\nEpoch 28 Loss: 0.8095047473907471\nMean Dice: 0.17197039127349853\nEpoch 29 Loss: 0.962582528591156\nMean Dice: 0.17208218201994896\nEpoch 30 Loss: 0.8462375402450562\nMean Dice: 0.17212391123175622\n" + "text": "Epoch 1 Loss: 0.7907230257987976\nMean Dice: 0.13321990370750428\nEpoch 2 Loss: 0.912385106086731\nMean Dice: 0.1390095144510269\nEpoch 3 Loss: 0.8601841330528259\nMean Dice: 0.1327591508626938\nEpoch 4 Loss: 0.8087334632873535\nMean Dice: 0.13687533736228943\nEpoch 5 Loss: 0.9294923543930054\nMean Dice: 0.13983858823776246\nEpoch 6 Loss: 0.8575614094734192\nMean Dice: 0.13973902463912963\nEpoch 7 Loss: 0.7714702486991882\nMean Dice: 0.1399323374032974\nEpoch 8 Loss: 0.897472083568573\nMean Dice: 0.13747023940086364\nEpoch 9 Loss: 0.9935440421104431\nMean Dice: 0.1336931049823761\nEpoch 10 Loss: 0.7999455332756042\nMean Dice: 0.13065092861652375\nEpoch 11 Loss: 0.9553474187850952\nMean Dice: 0.13199653327465058\n" } ], "source": [ diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d.py index dac5a8bf86..0f4f1d3ecb 100644 --- a/examples/unet_segmentation_3d.py +++ b/examples/unet_segmentation_3d.py @@ -20,8 +20,8 @@ import torch import monai.data.transforms.compose as transforms from torch.utils.tensorboard import SummaryWriter -from ignite.engine import Events, create_supervised_trainer -from ignite.handlers import ModelCheckpoint +from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator +from ignite.handlers import ModelCheckpoint, EarlyStopping from torch.utils.data import DataLoader from monai import application, networks, utils @@ -29,6 +29,7 @@ from monai.data.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch) from monai.application.handlers.stats_handler import StatsHandler from monai.networks.metrics.mean_dice import MeanDice +from monai.utils.stopperutils import stopping_fn_from_metric # assumes the framework is found here, change as necessary sys.path.append("..") @@ -60,7 +61,7 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma return noisyimage, labels - +# Create a temporary directory and 50 random image, mask paris tempdir = tempfile.mkdtemp() for i in range(50): @@ -75,18 +76,21 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) +# Define transforms for image and segmentation imtrans = transforms.Compose([Rescale(), AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()]) - segtrans = transforms.Compose([AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()]) +# Define nifti dataset, dataloader. ds = NiftiDataset(images, segs, imtrans, segtrans) - loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) im, seg = utils.mathutils.first(loader) print(im.shape, seg.shape) + lr = 1e-3 +train_epochs = 30 +# Create UNet, DiceLoss and Adam optimizer. net = networks.nets.UNet( dimensions=3, in_channels=1, @@ -101,13 +105,12 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma train_epochs = 3 - +# Since network outputs logits and segmentation, we need a custom function. def _loss_fn(i, j): return loss(i[0], j) - +# Create trainer device = torch.device("cuda:0") - trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) @@ -119,7 +122,7 @@ def _loss_fn(i, j): logging.basicConfig(stream=sys.stdout, level=logging.INFO) stats_logger = StatsHandler() -stats_logger.attach(trainer) +# stats_logger.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) @@ -154,6 +157,28 @@ def log_training_loss(engine): loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available()) +val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available()) writer = SummaryWriter() +# Define mean dice metric and Evaluator. +validation_every_N_epochs = 2 + +val_metrics = {'Mean Dice': MeanDice()} +val_stats_handler = StatsHandler() + +evaluator = create_supervised_evaluator(net, val_metrics, device, True, + output_transform=lambda x, y, y_pred: (y_pred[1], y)) +val_stats_handler.attach(evaluator) + +# Add early stopping handler to evaluator. +early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric('Mean Dice'), + trainer=trainer) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + +@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_N_epochs)) +def run_validation(engine): + evaluator.run(val_loader) + + state = trainer.run(loader, train_epochs) diff --git a/monai/utils/stopperutils.py b/monai/utils/stopperutils.py index e40c533c7f..bbb0c9ec83 100644 --- a/monai/utils/stopperutils.py +++ b/monai/utils/stopperutils.py @@ -18,5 +18,5 @@ def stopping_fn(engine): def stopping_fn_from_loss(): """Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.""" def stopping_fn(engine): - return -engine.state.loss.item() + return -engine.state.output return stopping_fn From 3e390527649950436f828341ef4e60f61d2dee1d Mon Sep 17 00:00:00 2001 From: Mohammad Adil Date: Wed, 12 Feb 2020 18:16:07 -0800 Subject: [PATCH 3/4] Fix flake8 errors. --- examples/unet_segmentation_3d.ipynb | 6 +++--- examples/unet_segmentation_3d.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index dc48b368c1..0073da35a8 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -185,7 +185,7 @@ "metadata": {}, "outputs": [], "source": [ - "validation_every_N_epochs = 1\n", + "validation_every_n_epochs = 1\n", "\n", "metrics = {'Mean Dice': MeanDice()}\n", "evaluator = create_supervised_evaluator(net, metrics, device, True,\n", @@ -202,7 +202,7 @@ " for name, value in engine.state.metrics.items():\n", " print(f\"{name}: {value}\")\n", "\n", - "@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_N_epochs))\n", + "@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))\n", "def run_validation(engine):\n", " evaluator.run(val_loader)\n", "\n" @@ -216,7 +216,7 @@ { "name": "stdout", "output_type": "stream", - "text": "Epoch 1 Loss: 0.7907230257987976\nMean Dice: 0.13321990370750428\nEpoch 2 Loss: 0.912385106086731\nMean Dice: 0.1390095144510269\nEpoch 3 Loss: 0.8601841330528259\nMean Dice: 0.1327591508626938\nEpoch 4 Loss: 0.8087334632873535\nMean Dice: 0.13687533736228943\nEpoch 5 Loss: 0.9294923543930054\nMean Dice: 0.13983858823776246\nEpoch 6 Loss: 0.8575614094734192\nMean Dice: 0.13973902463912963\nEpoch 7 Loss: 0.7714702486991882\nMean Dice: 0.1399323374032974\nEpoch 8 Loss: 0.897472083568573\nMean Dice: 0.13747023940086364\nEpoch 9 Loss: 0.9935440421104431\nMean Dice: 0.1336931049823761\nEpoch 10 Loss: 0.7999455332756042\nMean Dice: 0.13065092861652375\nEpoch 11 Loss: 0.9553474187850952\nMean Dice: 0.13199653327465058\n" + "text": "Epoch 1 Loss: 0.9377263188362122\nMean Dice: 0.10294072180986405\nEpoch 2 Loss: 0.9320536851882935\nMean Dice: 0.10453949496150017\nEpoch 3 Loss: 0.8483000993728638\nMean Dice: 0.1106492631137371\nEpoch 4 Loss: 0.8565940856933594\nMean Dice: 0.11204138025641441\nEpoch 5 Loss: 0.9039937853813171\nMean Dice: 0.11255192682147026\nEpoch 6 Loss: 0.8410466313362122\nMean Dice: 0.11070839315652847\nEpoch 7 Loss: 0.8163537979125977\nMean Dice: 0.11086311042308808\nEpoch 8 Loss: 0.8029956817626953\nMean Dice: 0.11151894852519036\nEpoch 9 Loss: 0.9074181914329529\nMean Dice: 0.10971762537956238\n" } ], "source": [ diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d.py index 0f4f1d3ecb..61843e752e 100644 --- a/examples/unet_segmentation_3d.py +++ b/examples/unet_segmentation_3d.py @@ -161,7 +161,7 @@ def log_training_loss(engine): writer = SummaryWriter() # Define mean dice metric and Evaluator. -validation_every_N_epochs = 2 +validation_every_n_epochs = 1 val_metrics = {'Mean Dice': MeanDice()} val_stats_handler = StatsHandler() @@ -176,7 +176,16 @@ def log_training_loss(engine): trainer=trainer) evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) +<<<<<<< HEAD @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_N_epochs)) +======= +@evaluator.on(Events.EPOCH_COMPLETED) +def print_metrics(engine): + for name, value in engine.state.metrics.items(): + print(f"{name}: {value}") + +@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) +>>>>>>> Fix flake8 errors. def run_validation(engine): evaluator.run(val_loader) From 417514208378c9d35c8e9c2e7a5434fe5bc65262 Mon Sep 17 00:00:00 2001 From: Mohammad Adil Date: Wed, 12 Feb 2020 23:41:54 -0800 Subject: [PATCH 4/4] Update example to work with StatsHandler. --- examples/unet_segmentation_3d.ipynb | 31 +++++++++++++++-------------- examples/unet_segmentation_3d.py | 19 +++++------------- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index 0073da35a8..b1dc49a8b1 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -22,7 +22,6 @@ "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", - "import monai.data.transforms.compose as transforms\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", @@ -34,6 +33,8 @@ "# assumes the framework is found here, change as necessary\n", "sys.path.append(\"..\")\n", "\n", + "\n", + "import monai.data.transforms.compose as transforms\n", "from monai import application, data, networks, utils\n", "from monai.data.readers import NiftiDataset\n", "from monai.data.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch, GridPatchDataset\n", @@ -46,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -77,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -95,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -130,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -151,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -181,15 +182,15 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "validation_every_n_epochs = 1\n", "\n", - "metrics = {'Mean Dice': MeanDice()}\n", - "evaluator = create_supervised_evaluator(net, metrics, device, True,\n", - " output_transform=lambda x, y, y_pred: (y_pred[1], y))\n", + "val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}\n", + "evaluator = create_supervised_evaluator(net, val_metrics, device, True,\n", + " output_transform=lambda x, y, y_pred: (y_pred[0], y))\n", "\n", "\n", "early_stopper = EarlyStopping(patience=4, \n", @@ -198,9 +199,9 @@ "evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)\n", "\n", "@evaluator.on(Events.EPOCH_COMPLETED)\n", - "def print_metrics(engine):\n", + "def log_validation_metrics(engine):\n", " for name, value in engine.state.metrics.items():\n", - " print(f\"{name}: {value}\")\n", + " print(\"Validation --\", name, \":\", value)\n", "\n", "@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))\n", "def run_validation(engine):\n", @@ -210,13 +211,13 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", - "text": "Epoch 1 Loss: 0.9377263188362122\nMean Dice: 0.10294072180986405\nEpoch 2 Loss: 0.9320536851882935\nMean Dice: 0.10453949496150017\nEpoch 3 Loss: 0.8483000993728638\nMean Dice: 0.1106492631137371\nEpoch 4 Loss: 0.8565940856933594\nMean Dice: 0.11204138025641441\nEpoch 5 Loss: 0.9039937853813171\nMean Dice: 0.11255192682147026\nEpoch 6 Loss: 0.8410466313362122\nMean Dice: 0.11070839315652847\nEpoch 7 Loss: 0.8163537979125977\nMean Dice: 0.11086311042308808\nEpoch 8 Loss: 0.8029956817626953\nMean Dice: 0.11151894852519036\nEpoch 9 Loss: 0.9074181914329529\nMean Dice: 0.10971762537956238\n" + "text": "Epoch 1 Loss: 0.8975554704666138\nValidation -- Mean Dice : 0.11846490800380707\nEpoch 2 Loss: 0.8451039791107178\nValidation -- Mean Dice : 0.12091563045978546\nEpoch 3 Loss: 0.9355515241622925\nValidation -- Mean Dice : 0.12139833569526673\nEpoch 4 Loss: 0.843208909034729\nValidation -- Mean Dice : 0.12108306288719177\nEpoch 5 Loss: 0.8225834965705872\nValidation -- Mean Dice : 0.12179622799158096\nEpoch 6 Loss: 0.957372784614563\nValidation -- Mean Dice : 0.12193384170532226\nEpoch 7 Loss: 0.9011092782020569\nValidation -- Mean Dice : 0.1230143740773201\nEpoch 8 Loss: 0.8651387691497803\nValidation -- Mean Dice : 0.1254110112786293\nEpoch 9 Loss: 0.8767974972724915\nValidation -- Mean Dice : 0.12633273899555206\nEpoch 10 Loss: 0.8193061947822571\nValidation -- Mean Dice : 0.12657881826162337\nEpoch 11 Loss: 0.9466649293899536\nValidation -- Mean Dice : 0.12699378579854964\nEpoch 12 Loss: 0.8258659243583679\nValidation -- Mean Dice : 0.12790720015764237\nEpoch 13 Loss: 0.8661612868309021\nValidation -- Mean Dice : 0.12980296313762665\nEpoch 14 Loss: 0.8039132356643677\nValidation -- Mean Dice : 0.1311295285820961\nEpoch 15 Loss: 0.8050084114074707\nValidation -- Mean Dice : 0.13225494623184203\nEpoch 16 Loss: 0.9048625230789185\nValidation -- Mean Dice : 0.1330576255917549\nEpoch 17 Loss: 0.9179995656013489\nValidation -- Mean Dice : 0.13361359685659407\nEpoch 18 Loss: 0.8956605195999146\nValidation -- Mean Dice : 0.13432369381189346\nEpoch 19 Loss: 0.8029189705848694\nValidation -- Mean Dice : 0.13532216250896453\nEpoch 20 Loss: 0.8359838128089905\nValidation -- Mean Dice : 0.13622953295707702\nEpoch 21 Loss: 0.9225850105285645\nValidation -- Mean Dice : 0.13677610754966735\nEpoch 22 Loss: 0.7023072242736816\nValidation -- Mean Dice : 0.13693425357341765\nEpoch 23 Loss: 0.8776397705078125\nValidation -- Mean Dice : 0.13710424304008484\nEpoch 24 Loss: 0.9571539163589478\nValidation -- Mean Dice : 0.1370883911848068\nEpoch 25 Loss: 0.8877002596855164\nValidation -- Mean Dice : 0.13701471388339997\nEpoch 26 Loss: 0.817417562007904\nValidation -- Mean Dice : 0.13696834743022918\nEpoch 27 Loss: 0.8971314430236816\nValidation -- Mean Dice : 0.1371448516845703\nEpoch 28 Loss: 0.9443905353546143\nValidation -- Mean Dice : 0.13739778995513915\nEpoch 29 Loss: 0.7578094005584717\nValidation -- Mean Dice : 0.137495020031929\nEpoch 30 Loss: 0.7037953734397888\nValidation -- Mean Dice : 0.13759489357471466\n" } ], "source": [ diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d.py index 61843e752e..a1f0acc5bf 100644 --- a/examples/unet_segmentation_3d.py +++ b/examples/unet_segmentation_3d.py @@ -122,7 +122,7 @@ def _loss_fn(i, j): logging.basicConfig(stream=sys.stdout, level=logging.INFO) stats_logger = StatsHandler() -# stats_logger.attach(trainer) +stats_logger.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) @@ -163,11 +163,11 @@ def log_training_loss(engine): # Define mean dice metric and Evaluator. validation_every_n_epochs = 1 -val_metrics = {'Mean Dice': MeanDice()} -val_stats_handler = StatsHandler() - +val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)} evaluator = create_supervised_evaluator(net, val_metrics, device, True, - output_transform=lambda x, y, y_pred: (y_pred[1], y)) + output_transform=lambda x, y, y_pred: (y_pred[0], y)) + +val_stats_handler = StatsHandler() val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. @@ -176,16 +176,7 @@ def log_training_loss(engine): trainer=trainer) evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) -<<<<<<< HEAD -@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_N_epochs)) -======= -@evaluator.on(Events.EPOCH_COMPLETED) -def print_metrics(engine): - for name, value in engine.state.metrics.items(): - print(f"{name}: {value}") - @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) ->>>>>>> Fix flake8 errors. def run_validation(engine): evaluator.run(val_loader)