[DLMED] add args and update default (#3418)
Signed-off-by: Nic Ma <nma@nvidia.com>
Nic-Ma authored Nov 30, 2021
1 parent 071264d commit ff9bbfa
Showing 4 changed files with 45 additions and 16 deletions.
19 changes: 15 additions & 4 deletions monai/data/dataset.py
@@ -26,6 +26,7 @@

import numpy as np
import torch
from torch.serialization import DEFAULT_PROTOCOL
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset

@@ -153,7 +154,7 @@ def __init__(
cache_dir: Optional[Union[Path, str]],
hash_func: Callable[..., bytes] = pickle_hashing,
pickle_module: str = "pickle",
pickle_protocol=pickle.DEFAULT_PROTOCOL,
pickle_protocol: int = DEFAULT_PROTOCOL,
) -> None:
"""
Args:
@@ -169,7 +170,12 @@
If `cache_dir` is `None`, there is effectively no caching.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
pickle_module: string representing the module used for pickling metadata and objects, default to `"pickle"`.
pickle_module: string representing the module used for pickling metadata and objects,
defaults to `"pickle"`. Due to the pickle limitation in multi-processing of DataLoader,
we can't pass the `pickle` module object directly as this arg, so a string name is used instead.
If you want to use another pickle module at runtime, just register it first:
>>> from monai.data import utils
>>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
@@ -319,7 +325,7 @@ def __init__(
cache_dir: Optional[Union[Path, str]],
hash_func: Callable[..., bytes] = pickle_hashing,
pickle_module: str = "pickle",
pickle_protocol=pickle.DEFAULT_PROTOCOL,
pickle_protocol: int = DEFAULT_PROTOCOL,
) -> None:
"""
Args:
@@ -336,7 +342,12 @@
If `cache_dir` is `None`, there is effectively no caching.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
pickle_module: string representing the module used for pickling metadata and objects, default to `"pickle"`.
pickle_module: string representing the module used for pickling metadata and objects,
defaults to `"pickle"`. Due to the pickle limitation in multi-processing of DataLoader,
we can't pass the `pickle` module object directly as this arg, so a string name is used instead.
If you want to use another pickle module at runtime, just register it first:
>>> from monai.data import utils
>>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
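For illustration only (not part of the diff): a minimal sketch of the registration mechanism described in the docstring above, assuming `dill` is installed and using placeholder file names and an arbitrary registry key; the string key passed as `pickle_module` must match the registered name.

# Sketch only: register an alternative pickle module and refer to it by name.
# Assumes `dill` is installed; the data file names below are placeholders.
import dill

from monai.data import PersistentDataset, utils
from monai.transforms import Compose, LoadImaged

utils.SUPPORTED_PICKLE_MOD["dill"] = dill  # register under a string key

dataset = PersistentDataset(
    data=[{"image": "img1.nii.gz"}, {"image": "img2.nii.gz"}],
    transform=Compose([LoadImaged(keys="image")]),
    cache_dir="./persistent_cache",
    pickle_module="dill",  # string name, looked up in SUPPORTED_PICKLE_MOD
    pickle_protocol=4,     # forwarded to torch.save
)

Passing the module as a string name rather than a module object keeps the dataset picklable when the DataLoader spawns worker processes, which is the limitation the docstring refers to.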
14 changes: 13 additions & 1 deletion monai/optimizers/lr_finder.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
@@ -17,6 +18,7 @@
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.serialization import DEFAULT_PROTOCOL
from torch.utils.data import DataLoader

from monai.networks.utils import eval_mode
@@ -183,6 +185,8 @@ def __init__(
memory_cache: bool = True,
cache_dir: Optional[str] = None,
amp: bool = False,
pickle_module=pickle,
pickle_protocol: int = DEFAULT_PROTOCOL,
verbose: bool = True,
) -> None:
"""Constructor.
@@ -202,6 +206,12 @@
specified, system-wide temporary directory is used. Notice that this
parameter will be ignored if `memory_cache` is True.
amp: use Automatic Mixed Precision
pickle_module: module used for pickling metadata and objects, defaults to `pickle`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
pickle_protocol: can be specified to override the default protocol, defaults to `2`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
verbose: verbose output
Returns:
None
@@ -221,7 +231,9 @@
# Save the original state of the model and optimizer so they can be restored if
# needed
self.model_device = next(self.model.parameters()).device
self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir)
self.state_cacher = StateCacher(
in_memory=memory_cache, cache_dir=cache_dir, pickle_module=pickle_module, pickle_protocol=pickle_protocol
)
self.state_cacher.store("model", self.model.state_dict())
self.state_cacher.store("optimizer", self.optimizer.state_dict())

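For context (not part of the diff): the two new arguments are simply handed to `StateCacher`, which forwards them to `torch.save` when the model/optimizer snapshots are written to disk. Roughly equivalent to the following sketch, where the path and protocol value are arbitrary placeholders.

# Rough sketch of how the new arguments reach torch.save (not the actual MONAI code).
import pickle
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
path = tempfile.gettempdir() + "/state_model.pt"  # placeholder path
torch.save(
    model.state_dict(),
    path,
    pickle_module=pickle,  # a module object, as accepted by torch.save
    pickle_protocol=4,     # overrides torch.serialization.DEFAULT_PROTOCOL (2)
)
model.load_state_dict(torch.load(path))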
18 changes: 8 additions & 10 deletions monai/utils/state_cacher.py
@@ -16,6 +16,7 @@
from typing import Dict, Optional

import torch
from torch.serialization import DEFAULT_PROTOCOL

from monai.config.type_definitions import PathLike

@@ -43,7 +44,7 @@ def __init__(
cache_dir: Optional[PathLike] = None,
allow_overwrite: bool = True,
pickle_module=pickle,
pickle_protocol=pickle.DEFAULT_PROTOCOL,
pickle_protocol: int = DEFAULT_PROTOCOL,
) -> None:
"""Constructor.
@@ -65,19 +66,16 @@
"""
self.in_memory = in_memory
self.cache_dir = cache_dir
self.cache_dir = tempfile.gettempdir() if cache_dir is None else cache_dir
if not os.path.isdir(self.cache_dir):
raise ValueError("Given `cache_dir` is not a valid directory.")

self.allow_overwrite = allow_overwrite
self.pickle_module = pickle_module
self.pickle_protocol = pickle_protocol
self.cached: Dict = {}

if self.cache_dir is None:
self.cache_dir = tempfile.gettempdir()
elif not os.path.isdir(self.cache_dir):
raise ValueError("Given `cache_dir` is not a valid directory.")

self.cached: Dict[str, str] = {}

def store(self, key, data_obj, pickle_module=None, pickle_protocol=None):
def store(self, key, data_obj, pickle_module=None, pickle_protocol: Optional[int] = None):
"""
Store a given object with the given key name.
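For illustration only (not part of the diff): a minimal sketch of the updated `StateCacher` behaviour, assuming the existing `store`/`retrieve` API; with `cache_dir=None` the snapshot now falls back to the system temporary directory instead of raising.

# Sketch only: cache a state dict on disk with an explicit pickle module/protocol.
import pickle

import torch.nn as nn
from monai.utils.state_cacher import StateCacher

cacher = StateCacher(
    in_memory=False,   # write to disk instead of keeping the copy in RAM
    cache_dir=None,    # falls back to tempfile.gettempdir()
    pickle_module=pickle,
    pickle_protocol=4,
)
net = nn.Linear(4, 2)
cacher.store("model", net.state_dict())
net.load_state_dict(cacher.retrieve("model"))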
10 changes: 9 additions & 1 deletion tests/test_lr_finder.py
@@ -10,6 +10,7 @@
# limitations under the License.

import os
import pickle
import random
import sys
import unittest
@@ -78,7 +79,14 @@ def test_lr_finder(self):
learning_rate = 1e-5
optimizer = torch.optim.Adam(model.parameters(), learning_rate)

lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device)
lr_finder = LearningRateFinder(
model=model,
optimizer=optimizer,
criterion=loss_function,
device=device,
pickle_module=pickle,
pickle_protocol=4,
)
lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5)
print(lr_finder.get_steepest_gradient(0, 0)[0])
