Add warning in a multiprocessing special case #6830

Merged: wyli merged 3 commits into Project-MONAI:dev from dataloader_warning on Aug 7, 2023

Contributor

@matt3o matt3o commented Aug 7, 2023

Related to discussion #6657

This PR at least adds a warning when set_track_meta(False) and multiprocessing_context='spawn' are used together. However, the warning only triggers if set_track_meta(False) was called before the DataLoader was initialized. I will append example code where that is not the case and the bug still occurs, without any warning.
IMO this is still a MONAI bug, even though the discussion claimed otherwise. Setting multiprocessing_context='spawn' has unintended consequences, and they affect only MONAI, not plain torch. (I believe the problem is that with 'spawn', Python and all libraries are reinitialized in the worker process, so _TRACK_META is reset to True.)
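
The hypothesis in the last sentence can be checked without MONAI. Below is a minimal sketch, assuming that a 'spawn' worker behaves like a freshly started interpreter that re-imports its modules. It uses subprocess as a stand-in for the spawned worker, and toy_meta is a hypothetical module written only for this demo, not MONAI's implementation.

```python
import os
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical stand-in for a library module that keeps a global flag.
TOY_MODULE = textwrap.dedent(
    """
    _TRACK_META = True  # import-time default

    def set_track_meta(val):
        global _TRACK_META
        _TRACK_META = bool(val)

    def get_track_meta():
        return _TRACK_META
    """
)


def compare_parent_and_fresh_interpreter():
    """Return (flag in this process, flag in a freshly started interpreter)."""
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "toy_meta.py").write_text(TOY_MODULE)
        sys.path.insert(0, tmp)
        try:
            import toy_meta  # imported from the temporary directory

            toy_meta.set_track_meta(False)  # changes the flag in this process only
            parent_flag = toy_meta.get_track_meta()

            # A new interpreter imports toy_meta from scratch, which is the
            # assumed behaviour of a 'spawn' worker.
            env = dict(os.environ, PYTHONPATH=tmp)
            result = subprocess.run(
                [sys.executable, "-c", "import toy_meta; print(toy_meta.get_track_meta())"],
                env=env,
                capture_output=True,
                text=True,
                check=True,
            )
            child_flag = result.stdout.strip() == "True"
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("toy_meta", None)
    return parent_flag, child_flag


if __name__ == "__main__":
    print(compare_parent_and_fresh_interpreter())
```

Under this assumption the parent reports False while the fresh interpreter reports the import-time default, True. Nothing in the parent's memory is carried over unless it is passed to the child explicitly.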

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Signed-off-by: Matthias Hadlich <matthiashadlich@posteo.de>
@matt3o matt3o force-pushed the dataloader_warning branch from 2f10523 to 518b967 Compare August 7, 2023 13:10
Contributor Author

matt3o commented Aug 7, 2023

import glob
import os
import logging
import tempfile

import numpy as np
import monai.transforms as mt
import torch
from monai.data import DataLoader, Dataset, create_test_image_3d, set_track_meta

import nibabel as nib

NETWORK_INPUT_SHAPE = (1, 128, 128, 256)
NUM_IMAGES = 1

logger = logging.getLogger("sw_interactive_segmentation")
if logger.hasHandlers():
    logger.handlers.clear()
logger.propagate = False
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler()
# (%(name)s)
formatter = logging.Formatter(
    fmt="[%(asctime)s.%(msecs)03d][%(levelname)s] %(funcName)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
stream_handler.setFormatter(formatter)
stream_handler.setLevel(logging.DEBUG)
logger.addHandler(stream_handler)


if __name__ == "__main__":
    print("### Run 1: Should trigger no warning")
    np.random.seed(seed=0)
    with tempfile.TemporaryDirectory() as tmpdirname:
        print(f"generating synthetic data to {tmpdirname} (this may take a while)")
        for i in range(1):
            pred, label = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1, noise_max=0.5)
            n = nib.Nifti1Image(pred, np.eye(4))
            nib.save(n, os.path.join(tmpdirname, f"pred{i:d}.nii.gz"))
            n = nib.Nifti1Image(label, np.eye(4))
            nib.save(n, os.path.join(tmpdirname, f"label{i:d}.nii.gz"))
        print(os.path.join(str(tmpdirname), "pred*.nii.gz"))
        images = sorted(glob.glob(os.path.join(str(tmpdirname), "pred*.nii.gz")))
        labels = sorted(glob.glob(os.path.join(str(tmpdirname), "label*.nii.gz")))
        datalist = [{"image": image, "label": label} for image, label in zip(images, labels)]

        device = "cuda"

        transform = mt.Compose(
            [
                mt.LoadImaged(
                    keys="image",
                    reader="ITKReader",
                    image_only=False,
                    simple_keys=True,
                ),
            ]
        )

        train_ds = Dataset(datalist, transform)

        train_ds2 = Dataset(datalist, transform)

        train_loader = DataLoader(
            train_ds,
            shuffle=True,
            num_workers=1,
            batch_size=1,
            multiprocessing_context="spawn",
        )

        train_loader2 = DataLoader(
            train_ds2,
            shuffle=True,
            num_workers=1,
            batch_size=1,
        )
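        # set_track_meta(False) is called after both DataLoaders exist, so the new
        # warning cannot fire; the 'spawn' loader still yields MetaTensor.
        set_track_meta(False)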

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for x in train_loader:
            print(type(x["image"]))

        for x in train_loader2:
            print(type(x["image"]))

        print(type(transform(datalist[0])["image"]))



    print("### Run 2: Should trigger a warning for the first data loader")
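    # set_track_meta(False) is called before the DataLoaders are created, so the
    # warning fires for the 'spawn' loader.
    set_track_meta(False)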
    np.random.seed(seed=0)
    with tempfile.TemporaryDirectory() as tmpdirname:
        print(f"generating synthetic data to {tmpdirname} (this may take a while)")
        for i in range(1):
            pred, label = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1, noise_max=0.5)
            n = nib.Nifti1Image(pred, np.eye(4))
            nib.save(n, os.path.join(tmpdirname, f"pred{i:d}.nii.gz"))
            n = nib.Nifti1Image(label, np.eye(4))
            nib.save(n, os.path.join(tmpdirname, f"label{i:d}.nii.gz"))
        print(os.path.join(str(tmpdirname), "pred*.nii.gz"))
        images = sorted(glob.glob(os.path.join(str(tmpdirname), "pred*.nii.gz")))
        labels = sorted(glob.glob(os.path.join(str(tmpdirname), "label*.nii.gz")))
        datalist = [{"image": image, "label": label} for image, label in zip(images, labels)]

        device = "cuda"

        transform = mt.Compose(
            [
                mt.LoadImaged(
                    keys="image",
                    reader="ITKReader",
                    image_only=False,
                    simple_keys=True,
                ),
            ]
        )

        train_ds = Dataset(datalist, transform)

        train_ds2 = Dataset(datalist, transform)

        train_loader = DataLoader(
            train_ds,
            shuffle=True,
            num_workers=1,
            batch_size=1,
            multiprocessing_context="spawn",
        )

        train_loader2 = DataLoader(
            train_ds2,
            shuffle=True,
            num_workers=1,
            batch_size=1,
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for x in train_loader:
            print(type(x["image"]))

        for x in train_loader2:
            print(type(x["image"]))

        print(type(transform(datalist[0])["image"]))

returns the following output:

(monai_git) mhadlich@i14pc101:~/code/sliding-window-based-interactive-segmentation-of-volumetric-medical-images$ python src/sw_interactive_segmentation/test_scripts/dataloader_bug.py
### Run 1: Should trigger no warning
generating synthetic data to /tmp/tmpfwwf9c0i (this may take a while)
/tmp/tmpfwwf9c0i/pred*.nii.gz
<class 'monai.data.meta_tensor.MetaTensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
### Run 2: Should trigger a warning for the first data loader
generating synthetic data to /tmp/tmpj1rcqbkg (this may take a while)
/tmp/tmpj1rcqbkg/pred*.nii.gz
Please be aware: Return type of the dataloader will not be a Tensor as expected but a MetaTensor instead! This is because 'spawn' creates a new process where _TRACK_META is initialized to True again. Context:_TRACK_META is set to False and multiprocessing_context to spawn
<class 'monai.data.meta_tensor.MetaTensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
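
The check behind the warning in Run 2 can be sketched roughly as follows. This is an illustration, not the code merged in this PR: uses_spawn and warn_on_spawn_without_meta are hypothetical names, and the message text is abbreviated.

```python
import multiprocessing
import warnings


def uses_spawn(multiprocessing_context):
    """True if the context (a start-method name or a context object) means 'spawn'."""
    if multiprocessing_context is None:
        return False
    if isinstance(multiprocessing_context, str):
        return multiprocessing_context == "spawn"
    # Context objects from multiprocessing.get_context() expose get_start_method().
    return multiprocessing_context.get_start_method() == "spawn"


def warn_on_spawn_without_meta(track_meta, multiprocessing_context=None):
    """Warn when metadata tracking is off in the parent but workers are spawned.

    Hypothetical helper: it only sees the flag value at call time, so a
    set_track_meta(False) issued afterwards goes unnoticed (as in Run 1).
    """
    if not track_meta and uses_spawn(multiprocessing_context):
        warnings.warn(
            "DataLoader workers started with 'spawn' re-import the library, "
            "so they may return MetaTensor although tracking was disabled."
        )
        return True
    return False


if __name__ == "__main__":
    # Both a start-method name and a context object are accepted.
    warn_on_spawn_without_meta(False, "spawn")
    warn_on_spawn_without_meta(False, multiprocessing.get_context("spawn"))
```

The sketch treats a missing context as "not spawn", which ignores platforms where 'spawn' is the default start method; a fuller check would have to account for that.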

Contributor

@wyli wyli left a comment

thanks for looking into this issue! the logger is not used in the PR, please remove it or change warnings.warn to logger.warn.

Signed-off-by: Matthias Hadlich <matthiashadlich@posteo.de>
Contributor Author

matt3o commented Aug 7, 2023

> thanks for looking into this issue! the logger is not used in the PR, please remove it or change warnings.warn to logger.warn.

Oh yeah, I totally forgot about it. If you think warnings.warn is fine, I will simply remove the import of the logger.
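
For context on the two options, here is a minimal standard-library sketch of how each route delivers a message: warnings.warn goes through the warnings filter, while a logger call goes to whatever handlers are attached. The logger name "demo", the ListHandler class, and the message text are made up for the example; logger.warning is used because logger.warn is a deprecated alias.

```python
import logging
import warnings

MESSAGE = "spawn workers may return MetaTensor"  # made-up text for the demo


class ListHandler(logging.Handler):
    """Logging handler that stores formatted messages in a list."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def via_warnings():
    """Emit through the warnings module and return the captured messages."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        warnings.warn(MESSAGE, UserWarning)
    return [str(w.message) for w in caught]


def via_logging():
    """Emit through a logger and return the messages seen by its handler."""
    logger = logging.getLogger("demo")
    handler = ListHandler()
    logger.addHandler(handler)
    try:
        logger.warning(MESSAGE)
    finally:
        logger.removeHandler(handler)
    return handler.messages


if __name__ == "__main__":
    print(via_warnings())
    print(via_logging())
```

Either route carries the same text; the difference is who controls visibility: the warnings filter (for example, by category) or the logging configuration (levels and handlers).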

Contributor

wyli commented Aug 7, 2023

/build

@wyli wyli enabled auto-merge (squash) August 7, 2023 14:28
@wyli wyli merged commit cb257d2 into Project-MONAI:dev Aug 7, 2023