Skip to content

Commit

Permalink
Add warning in a multiprocessing special case
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Hadlich <matthiashadlich@posteo.de>
  • Loading branch information
matt3o committed Aug 7, 2023
1 parent 65cf5fe commit 518b967
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions monai/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@

from __future__ import annotations

import warnings

import torch
from torch.utils.data import DataLoader as _TorchDataLoader
from torch.utils.data import Dataset

from monai.apps.utils import get_logger
from monai.data.meta_obj import get_track_meta
from monai.data.utils import list_data_collate, set_rnd, worker_init_fn

__all__ = ["DataLoader"]

logger = get_logger(module_name=__name__)


class DataLoader(_TorchDataLoader):
"""
Expand Down Expand Up @@ -88,4 +94,16 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
if "worker_init_fn" not in kwargs:
kwargs["worker_init_fn"] = worker_init_fn

if (
"multiprocessing_context" in kwargs
and kwargs["multiprocessing_context"] == "spawn"
and not get_track_meta()
):
warnings.warn(
"Please be aware: Return type of the dataloader will not be a Tensor as expected but"
" a MetaTensor instead! This is because 'spawn' creates a new process where _TRACK_META"
" is initialized to True again. Context:_TRACK_META is set to False and"
" multiprocessing_context to spawn"
)

super().__init__(dataset=dataset, num_workers=num_workers, **kwargs)

0 comments on commit 518b967

Please sign in to comment.