From 6f1e2261b80cba99129854c459df4661e75f8e05 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Mon, 27 Mar 2023 14:24:54 +0000 Subject: [PATCH 1/6] Add RankFilter to skip logging when the rank is not meeting user-specified criteria Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> --- monai/utils/__init__.py | 2 +- monai/utils/dist.py | 30 +++++++++++++++++- tests/min_tests.py | 1 + tests/test_rankfilter_dist.py | 59 +++++++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 tests/test_rankfilter_dist.py diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 75806ce120..5bccaba8a2 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -15,7 +15,7 @@ from .aliases import alias, resolve_name from .decorators import MethodReplacer, RestartGenerator from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default -from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather +from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( Average, BlendMode, diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 546058c93e..56dcfa19f1 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -12,6 +12,8 @@ from __future__ import annotations import sys +from collections.abc import Callable +from logging import Filter if sys.version_info >= (3, 8): from typing import Literal @@ -26,7 +28,7 @@ idist, has_ignite = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") -__all__ = ["get_dist_device", "evenly_divisible_all_gather", "string_list_all_gather"] +__all__ = ["get_dist_device", "evenly_divisible_all_gather", "string_list_all_gather", "RankFilter"] def get_dist_device(): @@ -174,3 +176,29 @@ def string_list_all_gather(strings: list[str], delimiter: str = "\t") -> list[st _gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered] return [i for k in _gathered for i in k] + + +class RankFilter(Filter): + """ + The RankFilter class is a convenient filter that extends the Filter class in the Python logging module. + The purpose is to control which log records are processed based on the rank in a distributed environment. + + Args: + filter_fn: an optional lambda function func used as the filtering criteria. + The default function logs only if the rank of the process is 0, + but the user can define their own function to implement custom filtering logic. + """ + + def __init__(self, filter_fn: Callable = lambda rank: rank == 0): + super().__init__() + self.filter_fn: Callable = filter_fn + if dist.is_available() and dist.is_initialized(): + self.rank: int = dist.get_rank() + else: + raise ValueError("The torch.distributed is either unavailable and uninitiated.") + + def filter(self, *_args): + if self.filter_fn(self.rank): + return True + else: + return False diff --git a/tests/min_tests.py b/tests/min_tests.py index 05f117013e..ab5c1db826 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -156,6 +156,7 @@ def run_testsuit(): "test_rand_zoom", "test_rand_zoomd", "test_randtorchvisiond", + "test_rankfilter_dist", "test_resample_backends", "test_resize", "test_resized", diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py new file mode 100644 index 0000000000..d7ff331020 --- /dev/null +++ b/tests/test_rankfilter_dist.py @@ -0,0 +1,59 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist + +from monai.utils import RankFilter +from tests.utils import DistCall, DistTestCase + + +class DistributedRankFilterTest(DistTestCase): + def setUp(self): + self.log_dir = tempfile.TemporaryDirectory() + + @DistCall(nnodes=1, nproc_per_node=2) + def test_even(self): + logger = logging.getLogger(__name__) + log_filename = os.path.join(self.log_dir.name, "records.log") + h1 = logging.FileHandler(filename=log_filename) + h1.setLevel(logging.WARNING) + + logger.addHandler(h1) + + if torch.cuda.device_count() > 1: + dist.init_process_group(backend="nccl", init_method="env://") + rank_filer = RankFilter() + logger.addFilter(rank_filer) + + logger.warning("test_warnings") + + dist.barrier() + if dist.get_rank() == 0: + with open(log_filename) as file: + lines = [line.rstrip() for line in file] + log_message = " ".join(lines) + + assert log_message.count("test_warnings") == 1 + + def tearDown(self) -> None: + self.log_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() From b35230ec0159dfbfe9379b087966508fac31595d Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Mon, 27 Mar 2023 15:03:13 +0000 Subject: [PATCH 2/6] fix comment Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> --- monai/utils/dist.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 56dcfa19f1..5f3d26caaa 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -198,7 +198,4 @@ def __init__(self, filter_fn: Callable = lambda rank: rank == 0): raise ValueError("The torch.distributed is either unavailable and uninitiated.") def filter(self, *_args): - if self.filter_fn(self.rank): - return True - else: - return False + return self.filter_fn(self.rank) From a452056caada43c0c802a810fe2b59d009878b8f Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Mon, 27 Mar 2023 15:17:37 +0000 Subject: [PATCH 3/6] change to warning in non-dist call Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> --- monai/utils/dist.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 5f3d26caaa..87a95473b1 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -12,6 +12,7 @@ from __future__ import annotations import sys +import warnings from collections.abc import Callable from logging import Filter @@ -184,18 +185,23 @@ class RankFilter(Filter): The purpose is to control which log records are processed based on the rank in a distributed environment. Args: - filter_fn: an optional lambda function func used as the filtering criteria. + rank: the rank of the process in the torch.distributed. Default is None and then it will use dist.get_rank(). + filter_fn: an optional lambda function used as the filtering criteria. The default function logs only if the rank of the process is 0, but the user can define their own function to implement custom filtering logic. """ - def __init__(self, filter_fn: Callable = lambda rank: rank == 0): + def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: rank == 0): super().__init__() self.filter_fn: Callable = filter_fn if dist.is_available() and dist.is_initialized(): - self.rank: int = dist.get_rank() + self.rank: int = rank if rank else dist.get_rank() else: - raise ValueError("The torch.distributed is either unavailable and uninitiated.") + warnings.warn( + "The torch.distributed is either unavailable and uninitiated when RankFilter is instiantiated. " + "If torch.distributed is used, please ensure that the RankFilter() is called " + "after torch.distributed.init_process_group() in the script." + ) def filter(self, *_args): return self.filter_fn(self.rank) From 66ce46fe2f2383abec0e7f71db8fae7b8152f86a Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Tue, 28 Mar 2023 06:44:28 +0000 Subject: [PATCH 4/6] fix comment Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> --- monai/utils/dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 87a95473b1..c476ace73b 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -195,7 +195,7 @@ def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: r super().__init__() self.filter_fn: Callable = filter_fn if dist.is_available() and dist.is_initialized(): - self.rank: int = rank if rank else dist.get_rank() + self.rank: int = rank if rank is not None else dist.get_rank() else: warnings.warn( "The torch.distributed is either unavailable and uninitiated when RankFilter is instiantiated. " From e4a50d1405ad07f9182b8d9a801422ac80fbe38d Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Tue, 28 Mar 2023 06:58:41 +0000 Subject: [PATCH 5/6] fix test errors Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> --- tests/test_rankfilter_dist.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py index d7ff331020..92517c6204 100644 --- a/tests/test_rankfilter_dist.py +++ b/tests/test_rankfilter_dist.py @@ -28,7 +28,7 @@ def setUp(self): self.log_dir = tempfile.TemporaryDirectory() @DistCall(nnodes=1, nproc_per_node=2) - def test_even(self): + def test_rankfilter(self): logger = logging.getLogger(__name__) log_filename = os.path.join(self.log_dir.name, "records.log") h1 = logging.FileHandler(filename=log_filename) @@ -36,11 +36,7 @@ def test_even(self): logger.addHandler(h1) - if torch.cuda.device_count() > 1: - dist.init_process_group(backend="nccl", init_method="env://") - rank_filer = RankFilter() - logger.addFilter(rank_filer) - + logger.addFilter(RankFilter()) logger.warning("test_warnings") dist.barrier() @@ -48,7 +44,6 @@ def test_even(self): with open(log_filename) as file: lines = [line.rstrip() for line in file] log_message = " ".join(lines) - assert log_message.count("test_warnings") == 1 def tearDown(self) -> None: From d99fd6c9176b07b64f7cf7b04af1c56104ab4f88 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Mar 2023 06:59:47 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_rankfilter_dist.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py index 92517c6204..4dcd637c56 100644 --- a/tests/test_rankfilter_dist.py +++ b/tests/test_rankfilter_dist.py @@ -16,7 +16,6 @@ import tempfile import unittest -import torch import torch.distributed as dist from monai.utils import RankFilter