diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 989de09c18..f1b481d598 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -26,6 +26,7 @@

 import numpy as np
 import torch
+from torch.serialization import DEFAULT_PROTOCOL
 from torch.utils.data import Dataset as _TorchDataset
 from torch.utils.data import Subset

@@ -153,7 +154,7 @@ def __init__(
         cache_dir: Optional[Union[Path, str]],
         hash_func: Callable[..., bytes] = pickle_hashing,
         pickle_module: str = "pickle",
-        pickle_protocol=pickle.DEFAULT_PROTOCOL,
+        pickle_protocol: int = DEFAULT_PROTOCOL,
     ) -> None:
         """
         Args:
@@ -169,7 +170,12 @@ def __init__(
                 If `cache_dir` is `None`, there is effectively no caching.
             hash_func: a callable to compute hash from data items to be cached.
                 defaults to `monai.data.utils.pickle_hashing`.
-            pickle_module: string representing the module used for pickling metadata and objects, default to `"pickle"`.
+            pickle_module: string representing the module used for pickling metadata and objects,
+                default to `"pickle"`. due to the pickle limitation in multi-processing of DataLoader,
+                this arg must be a string name rather than the module object itself.
+                to use another pickle module at runtime, register it first, for example:
+                >>> from monai.data import utils
+                >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle
                 this arg is used by `torch.save`, for more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                 and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
@@ -319,7 +325,7 @@ def __init__(
         cache_dir: Optional[Union[Path, str]],
         hash_func: Callable[..., bytes] = pickle_hashing,
         pickle_module: str = "pickle",
-        pickle_protocol=pickle.DEFAULT_PROTOCOL,
+        pickle_protocol: int = DEFAULT_PROTOCOL,
     ) -> None:
         """
         Args:
@@ -336,7 +342,12 @@ def __init__(
                 If `cache_dir` is `None`, there is effectively no caching.
             hash_func: a callable to compute hash from data items to be cached.
                 defaults to `monai.data.utils.pickle_hashing`.
-            pickle_module: string representing the module used for pickling metadata and objects, default to `"pickle"`.
+            pickle_module: string representing the module used for pickling metadata and objects,
+                default to `"pickle"`. due to the pickle limitation in multi-processing of DataLoader,
+                this arg must be a string name rather than the module object itself.
+                to use another pickle module at runtime, register it first, for example:
+                >>> from monai.data import utils
+                >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle
                 this arg is used by `torch.save`, for more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                 and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
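The docstring above only sketches the registration pattern, so here is a rough, self-contained example (not part of the patch) of how it might be used end to end. The `dill` module stands in for any pickle-compatible third-party pickler and is assumed to be installed; the in-memory sample, transform chain and cache directory are likewise illustrative. `monai.data.utils.SUPPORTED_PICKLE_MOD` and the `pickle_module`/`pickle_protocol` arguments are the ones touched by the diff.

import dill  # assumption: any pickle-compatible module would do; dill is just an example
import numpy as np

from monai.data import PersistentDataset, utils
from monai.transforms import Compose, ScaleIntensityd

# register the module under a string key; the dataset then refers to it by name,
# which sidesteps the pickle limitation described in the docstring above
utils.SUPPORTED_PICKLE_MOD["dill"] = dill

ds = PersistentDataset(
    data=[{"image": np.random.rand(1, 16, 16).astype(np.float32)}],  # illustrative sample
    transform=Compose([ScaleIntensityd(keys="image")]),
    cache_dir="./persistent_cache",  # illustrative cache location
    pickle_module="dill",            # resolved via SUPPORTED_PICKLE_MOD when saving
    pickle_protocol=4,               # forwarded to torch.save
)
print(ds[0]["image"].shape)  # first access computes the transform and caches it to disk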
diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py
index d1ab1b44e1..9cd66c226d 100644
--- a/monai/optimizers/lr_finder.py
+++ b/monai/optimizers/lr_finder.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import pickle
 import warnings
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
@@ -17,6 +18,7 @@
 import torch
 import torch.nn as nn
 from torch.optim import Optimizer
+from torch.serialization import DEFAULT_PROTOCOL
 from torch.utils.data import DataLoader

 from monai.networks.utils import eval_mode
@@ -183,6 +185,8 @@
         memory_cache: bool = True,
         cache_dir: Optional[str] = None,
         amp: bool = False,
+        pickle_module=pickle,
+        pickle_protocol: int = DEFAULT_PROTOCOL,
         verbose: bool = True,
     ) -> None:
         """Constructor.
@@ -202,6 +206,12 @@
                 specified, system-wide temporary directory is used. Notice that this
                 parameter will be ignored if `memory_cache` is True.
            amp: use Automatic Mixed Precision
+            pickle_module: module used for pickling metadata and objects, default to `pickle`.
+                this arg is used by `torch.save`, for more details, please check:
+                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
+            pickle_protocol: can be specified to override the default protocol, default to `2`.
+                this arg is used by `torch.save`, for more details, please check:
+                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
             verbose: verbose output
         Returns:
             None
@@ -221,7 +231,9 @@
         # Save the original state of the model and optimizer so they can be restored if
         # needed
         self.model_device = next(self.model.parameters()).device
-        self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir)
+        self.state_cacher = StateCacher(
+            in_memory=memory_cache, cache_dir=cache_dir, pickle_module=pickle_module, pickle_protocol=pickle_protocol
+        )
         self.state_cacher.store("model", self.model.state_dict())
         self.state_cacher.store("optimizer", self.optimizer.state_dict())

diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py
index ddbf6f90c5..2cf0271db2 100644
--- a/monai/utils/state_cacher.py
+++ b/monai/utils/state_cacher.py
@@ -16,6 +16,7 @@
 from typing import Dict, Optional

 import torch
+from torch.serialization import DEFAULT_PROTOCOL

 from monai.config.type_definitions import PathLike

@@ -43,7 +44,7 @@ def __init__(
         cache_dir: Optional[PathLike] = None,
         allow_overwrite: bool = True,
         pickle_module=pickle,
-        pickle_protocol=pickle.DEFAULT_PROTOCOL,
+        pickle_protocol: int = DEFAULT_PROTOCOL,
     ) -> None:
         """Constructor.

@@ -65,19 +66,16 @@

         """
         self.in_memory = in_memory
-        self.cache_dir = cache_dir
+        self.cache_dir = tempfile.gettempdir() if cache_dir is None else cache_dir
+        if not os.path.isdir(self.cache_dir):
+            raise ValueError("Given `cache_dir` is not a valid directory.")
+
         self.allow_overwrite = allow_overwrite
         self.pickle_module = pickle_module
         self.pickle_protocol = pickle_protocol
+        self.cached: Dict = {}

-        if self.cache_dir is None:
-            self.cache_dir = tempfile.gettempdir()
-        elif not os.path.isdir(self.cache_dir):
-            raise ValueError("Given `cache_dir` is not a valid directory.")
-
-        self.cached: Dict[str, str] = {}
-
-    def store(self, key, data_obj, pickle_module=None, pickle_protocol=None):
+    def store(self, key, data_obj, pickle_module=None, pickle_protocol: Optional[int] = None):
         """
         Store a given object with the given key name.

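For context, a rough sketch (not part of the patch) of driving the updated `StateCacher` directly, mirroring what `LearningRateFinder` now does internally. The toy linear model, the `in_memory=False` choice and the use of the system temp directory are illustrative assumptions; the constructor arguments and the store/retrieve round trip are the class's existing API.

import pickle
import tempfile

import torch

from monai.utils import StateCacher

model = torch.nn.Linear(4, 2)  # toy model, illustrative only

state_cacher = StateCacher(
    in_memory=False,                  # cache to disk so torch.save is actually exercised
    cache_dir=tempfile.gettempdir(),  # must be an existing directory (see the new validation above)
    pickle_module=pickle,             # module object, passed through to torch.save
    pickle_protocol=4,                # overrides the default protocol (2)
)
state_cacher.store("model", model.state_dict())
model.load_state_dict(state_cacher.retrieve("model"))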
diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py
index 5b730c2a77..c3a7c83448 100644
--- a/tests/test_lr_finder.py
+++ b/tests/test_lr_finder.py
@@ -10,6 +10,7 @@
 # limitations under the License.

 import os
+import pickle
 import random
 import sys
 import unittest
@@ -78,7 +79,14 @@ def test_lr_finder(self):
         learning_rate = 1e-5
         optimizer = torch.optim.Adam(model.parameters(), learning_rate)

-        lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device)
+        lr_finder = LearningRateFinder(
+            model=model,
+            optimizer=optimizer,
+            criterion=loss_function,
+            device=device,
+            pickle_module=pickle,
+            pickle_protocol=4,
+        )
         lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5)
         print(lr_finder.get_steepest_gradient(0, 0)[0])
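Outside the MedNIST-based test harness, the same constructor arguments could be exercised with a plain PyTorch DataLoader along the lines below. This is a sketch only: the toy model, random tensor dataset and learning-rate bounds are illustrative assumptions, and batches are assumed to be (inputs, labels) pairs that the finder's default batch extractors can unpack.

import pickle

import torch
from torch.utils.data import DataLoader, TensorDataset

from monai.optimizers import LearningRateFinder

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
criterion = torch.nn.CrossEntropyLoss()
# illustrative data: each batch unpacks as (inputs, labels)
train_loader = DataLoader(
    TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,))), batch_size=8
)

lr_finder = LearningRateFinder(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device="cpu",
    pickle_module=pickle,  # forwarded to the internal StateCacher, then to torch.save
    pickle_protocol=4,
)
lr_finder.range_test(train_loader, end_lr=10, num_iter=5)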