Skip to content

Commit

Permalink
Unify env vars access (#7084)
Browse files Browse the repository at this point in the history
Fixes #6879.

### Description

Some environment variable called using `os.get.environ("MONAI_VAR")`
instead of `monai.utils.MONAIEnvVars`.

Can't use `monai.utils.MONAIEnvVars` in `monai.utils.module.py` due to
circular imports.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: vgrau98 <victor.grau93@gmail.com>
  • Loading branch information
vgrau98 authored Oct 4, 2023
1 parent 65e8f5b commit e8b79f7
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 14 deletions.
4 changes: 3 additions & 1 deletion monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
from monai.config import PathLike
from monai.utils import ensure_tuple, look_up_option, run_cmd
from monai.utils.enums import AlgoKeys
from monai.utils.misc import MONAIEnvVars

logger = get_logger(module_name=__name__)
ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "3e11dd0")
ALGO_HASH = MONAIEnvVars.algo_hash()


__all__ = ["BundleAlgo", "BundleGen"]

Expand Down
4 changes: 2 additions & 2 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ def get_surface_distance(
"""
lib: ModuleType = torch if isinstance(seg_pred, torch.Tensor) else np
if not seg_gt.any():
dis = lib.inf * lib.ones_like(seg_gt, dtype=lib.float32)
dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32)
else:
if not lib.any(seg_pred):
dis = lib.inf * lib.ones_like(seg_gt, dtype=lib.float32)
dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32)
dis = dis[seg_gt]
return convert_to_dst_type(dis, seg_pred, dtype=dis.dtype)[0]
if distance_metric == "euclidean":
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def copy_model_state(
if inplace and isinstance(dst, torch.nn.Module):
if isinstance(dst, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
dst = dst.module
dst.load_state_dict(dst_dict)
dst.load_state_dict(dst_dict) # type: ignore
return dst_dict, updated_keys, unchanged_keys


Expand Down
2 changes: 1 addition & 1 deletion monai/optimizers/novograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __setstate__(self, state):
for group in self.param_groups:
group.setdefault("amsgrad", False)

def step(self, closure: Callable[[], T] | None = None) -> T | None:
def step(self, closure: Callable[[], T] | None = None) -> T | None: # type: ignore
"""Performs a single optimization step.
Arguments:
Expand Down
4 changes: 2 additions & 2 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from __future__ import annotations

import os
import warnings
from collections.abc import Hashable, Mapping
from contextlib import contextmanager
Expand All @@ -34,6 +33,7 @@
convert_to_numpy,
convert_to_tensor,
)
from monai.utils.misc import MONAIEnvVars

__all__ = ["TraceableTransform", "InvertibleTransform"]

Expand Down Expand Up @@ -70,7 +70,7 @@ class TraceableTransform(Transform):
`MONAI_TRACE_TRANSFORM` when initializing the class.
"""

tracing = os.environ.get("MONAI_TRACE_TRANSFORM", "1") != "0"
tracing = MONAIEnvVars.trace_transform() != "0"

def set_tracing(self, tracing: bool) -> None:
"""Set whether to trace transforms."""
Expand Down
24 changes: 24 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,30 @@ def debug() -> bool:
def doc_images() -> str | None:
return os.environ.get("MONAI_DOC_IMAGES")

@staticmethod
def algo_hash() -> str | None:
return os.environ.get("MONAI_ALGO_HASH", "e01d67a")

@staticmethod
def trace_transform() -> str | None:
return os.environ.get("MONAI_TRACE_TRANSFORM", "1")

@staticmethod
def eval_expr() -> str | None:
return os.environ.get("MONAI_EVAL_EXPR", "1")

@staticmethod
def allow_missing_reference() -> str | None:
return os.environ.get("MONAI_ALLOW_MISSING_REFERENCE", "1")

@staticmethod
def extra_test_data() -> str | None:
return os.environ.get("MONAI_EXTRA_TEST_DATA", "1")

@staticmethod
def testing_algo_template() -> str | None:
return os.environ.get("MONAI_TESTING_ALGO_TEMPLATE", None)


class ImageMetaKey:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_monai_env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ class TestMONAIEnvVars(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
cls.orig_value = os.environ.get("MONAI_DEBUG")
cls.orig_value = str(MONAIEnvVars.debug())

@classmethod
def tearDownClass(cls):
if cls.orig_value is not None:
os.environ["MONAI_DEBUG"] = cls.orig_value
else:
os.environ.pop("MONAI_DEBUG")
print("MONAI debug value:", os.environ.get("MONAI_DEBUG"))
print("MONAI debug value:", str(MONAIEnvVars.debug()))
super(__class__, cls).tearDownClass()

def test_monai_env_vars(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_monai_utils_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from parameterized import parameterized

from monai.utils.misc import check_kwargs_exist_in_class_init, run_cmd, to_tuple_of_dictionaries
from monai.utils.misc import MONAIEnvVars, check_kwargs_exist_in_class_init, run_cmd, to_tuple_of_dictionaries

TO_TUPLE_OF_DICTIONARIES_TEST_CASES = [
({}, tuple(), tuple()),
Expand Down Expand Up @@ -75,7 +75,7 @@ def _custom_user_function(self, cls, *args, **kwargs):

class TestCommandRunner(unittest.TestCase):
def setUp(self):
self.orig_flag = os.environ.get("MONAI_DEBUG")
self.orig_flag = str(MONAIEnvVars.debug())

def tearDown(self):
if self.orig_flag is not None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_network_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

import monai.networks.nets as nets
from monai.utils import set_determinism
from monai.utils.misc import MONAIEnvVars
from tests.utils import assert_allclose

extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA")
extra_test_data_dir = MONAIEnvVars.extra_test_data()

TESTS = []
if extra_test_data_dir is not None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import monai.transforms as mt
from monai.data import Dataset
from monai.utils.misc import MONAIEnvVars


class FaultyTransform(mt.Transform):
Expand All @@ -31,7 +32,7 @@ class TestTransform(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
cls.orig_value = os.environ.get("MONAI_DEBUG")
cls.orig_value = str(MONAIEnvVars.debug())

@classmethod
def tearDownClass(cls):
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from monai.data.meta_tensor import MetaTensor, get_track_meta
from monai.networks import convert_to_onnx, convert_to_torchscript
from monai.utils import optional_import
from monai.utils.misc import MONAIEnvVars
from monai.utils.module import pytorch_after
from monai.utils.tf32 import detect_default_tf32
from monai.utils.type_conversion import convert_data_type
Expand Down Expand Up @@ -75,7 +76,7 @@ def get_testing_algo_template_path():
https://github.com/Project-MONAI/MONAI/blob/1.1.0/monai/apps/auto3dseg/bundle_gen.py#L380-L381
"""
return os.environ.get("MONAI_TESTING_ALGO_TEMPLATE", None)
return MONAIEnvVars.testing_algo_template()


def clone(data: NdarrayTensor) -> NdarrayTensor:
Expand Down

0 comments on commit e8b79f7

Please sign in to comment.