Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify env vars access #7084

Merged
merged 9 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
from monai.config import PathLike
from monai.utils import ensure_tuple, look_up_option, run_cmd
from monai.utils.enums import AlgoKeys
from monai.utils.misc import MONAIEnvVars

logger = get_logger(module_name=__name__)
ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "3e11dd0")
ALGO_HASH = MONAIEnvVars.algo_hash()


__all__ = ["BundleAlgo", "BundleGen"]

Expand Down
4 changes: 2 additions & 2 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ def get_surface_distance(
"""
lib: ModuleType = torch if isinstance(seg_pred, torch.Tensor) else np
if not seg_gt.any():
dis = lib.inf * lib.ones_like(seg_gt, dtype=lib.float32)
dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32)
else:
if not lib.any(seg_pred):
dis = lib.inf * lib.ones_like(seg_gt, dtype=lib.float32)
dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32)
dis = dis[seg_gt]
return convert_to_dst_type(dis, seg_pred, dtype=dis.dtype)[0]
if distance_metric == "euclidean":
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def copy_model_state(
if inplace and isinstance(dst, torch.nn.Module):
if isinstance(dst, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
dst = dst.module
dst.load_state_dict(dst_dict)
dst.load_state_dict(dst_dict) # type: ignore
return dst_dict, updated_keys, unchanged_keys


Expand Down
2 changes: 1 addition & 1 deletion monai/optimizers/novograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __setstate__(self, state):
for group in self.param_groups:
group.setdefault("amsgrad", False)

def step(self, closure: Callable[[], T] | None = None) -> T | None:
def step(self, closure: Callable[[], T] | None = None) -> T | None: # type: ignore
"""Performs a single optimization step.

Arguments:
Expand Down
4 changes: 2 additions & 2 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from __future__ import annotations

import os
import warnings
from collections.abc import Hashable, Mapping
from contextlib import contextmanager
Expand All @@ -34,6 +33,7 @@
convert_to_numpy,
convert_to_tensor,
)
from monai.utils.misc import MONAIEnvVars

__all__ = ["TraceableTransform", "InvertibleTransform"]

Expand Down Expand Up @@ -70,7 +70,7 @@ class TraceableTransform(Transform):
`MONAI_TRACE_TRANSFORM` when initializing the class.
"""

tracing = os.environ.get("MONAI_TRACE_TRANSFORM", "1") != "0"
tracing = MONAIEnvVars.trace_transform() != "0"

def set_tracing(self, tracing: bool) -> None:
"""Set whether to trace transforms."""
Expand Down
24 changes: 24 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,30 @@ def debug() -> bool:
def doc_images() -> str | None:
return os.environ.get("MONAI_DOC_IMAGES")

@staticmethod
def algo_hash() -> str | None:
return os.environ.get("MONAI_ALGO_HASH", "e01d67a")

@staticmethod
def trace_transform() -> str | None:
return os.environ.get("MONAI_TRACE_TRANSFORM", "1")

@staticmethod
def eval_expr() -> str | None:
return os.environ.get("MONAI_EVAL_EXPR", "1")

@staticmethod
def allow_missing_reference() -> str | None:
return os.environ.get("MONAI_ALLOW_MISSING_REFERENCE", "1")

@staticmethod
def extra_test_data() -> str | None:
return os.environ.get("MONAI_EXTRA_TEST_DATA", "1")

@staticmethod
def testing_algo_template() -> str | None:
return os.environ.get("MONAI_TESTING_ALGO_TEMPLATE", None)


class ImageMetaKey:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_monai_env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ class TestMONAIEnvVars(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
cls.orig_value = os.environ.get("MONAI_DEBUG")
cls.orig_value = str(MONAIEnvVars.debug())

@classmethod
def tearDownClass(cls):
if cls.orig_value is not None:
os.environ["MONAI_DEBUG"] = cls.orig_value
else:
os.environ.pop("MONAI_DEBUG")
print("MONAI debug value:", os.environ.get("MONAI_DEBUG"))
print("MONAI debug value:", str(MONAIEnvVars.debug()))
super(__class__, cls).tearDownClass()

def test_monai_env_vars(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_monai_utils_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from parameterized import parameterized

from monai.utils.misc import check_kwargs_exist_in_class_init, run_cmd, to_tuple_of_dictionaries
from monai.utils.misc import MONAIEnvVars, check_kwargs_exist_in_class_init, run_cmd, to_tuple_of_dictionaries

TO_TUPLE_OF_DICTIONARIES_TEST_CASES = [
({}, tuple(), tuple()),
Expand Down Expand Up @@ -75,7 +75,7 @@ def _custom_user_function(self, cls, *args, **kwargs):

class TestCommandRunner(unittest.TestCase):
def setUp(self):
self.orig_flag = os.environ.get("MONAI_DEBUG")
self.orig_flag = str(MONAIEnvVars.debug())

def tearDown(self):
if self.orig_flag is not None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_network_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

import monai.networks.nets as nets
from monai.utils import set_determinism
from monai.utils.misc import MONAIEnvVars
from tests.utils import assert_allclose

extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA")
extra_test_data_dir = MONAIEnvVars.extra_test_data()

TESTS = []
if extra_test_data_dir is not None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import monai.transforms as mt
from monai.data import Dataset
from monai.utils.misc import MONAIEnvVars


class FaultyTransform(mt.Transform):
Expand All @@ -31,7 +32,7 @@ class TestTransform(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
cls.orig_value = os.environ.get("MONAI_DEBUG")
cls.orig_value = str(MONAIEnvVars.debug())

@classmethod
def tearDownClass(cls):
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from monai.data.meta_tensor import MetaTensor, get_track_meta
from monai.networks import convert_to_onnx, convert_to_torchscript
from monai.utils import optional_import
from monai.utils.misc import MONAIEnvVars
from monai.utils.module import pytorch_after
from monai.utils.tf32 import detect_default_tf32
from monai.utils.type_conversion import convert_data_type
Expand Down Expand Up @@ -75,7 +76,7 @@ def get_testing_algo_template_path():

https://github.com/Project-MONAI/MONAI/blob/1.1.0/monai/apps/auto3dseg/bundle_gen.py#L380-L381
"""
return os.environ.get("MONAI_TESTING_ALGO_TEMPLATE", None)
return MONAIEnvVars.testing_algo_template()


def clone(data: NdarrayTensor) -> NdarrayTensor:
Expand Down