From 95ea95fcd642488519bb599e4618507f10f88494 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:51:11 -0700 Subject: [PATCH] Free memory in universal checkpointing tests (#6693) Tests in universal checkpointing were not freeing the engine after use when `reuse_dist_env` was set to `True`, leading to memory leaks. This PR ensures freeing the engine in the tests and enables `reuse_dist_env`. Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- tests/unit/checkpoint/test_universal_checkpoint.py | 5 +++-- tests/unit/common.py | 9 --------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/unit/checkpoint/test_universal_checkpoint.py b/tests/unit/checkpoint/test_universal_checkpoint.py index f2692ecba3a6..27ddf0cdef39 100644 --- a/tests/unit/checkpoint/test_universal_checkpoint.py +++ b/tests/unit/checkpoint/test_universal_checkpoint.py @@ -131,8 +131,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, torch.save((model_state, optimizer_state), os.path.join(tmpdir, "baseline_state.pt")) dist.barrier() - - return model, sd + model.destroy() @pytest.fixture @@ -213,6 +212,8 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam): univ_model.backward(loss) univ_model.step() + univ_model.destroy() + @pytest.mark.world_size(2) def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam): self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam) diff --git a/tests/unit/common.py b/tests/unit/common.py index 685f943df2fe..f46ac666f27b 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -25,8 +25,6 @@ # Worker timeout for tests that hang DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600')) -warn_reuse_dist_env = False - def is_rocm_pytorch(): return hasattr(torch.version, 'hip') and torch.version.hip is not None @@ -178,13 +176,6 @@
def _launch_daemonic_procs(self, num_procs, init_method): print("Ignoring reuse_dist_env for hpu") self.reuse_dist_env = False - global warn_reuse_dist_env - if self.reuse_dist_env and not warn_reuse_dist_env: - # Currently we see memory leak for tests that reuse distributed environment - print("Ignoring reuse_dist_env and forcibly setting it to False") - warn_reuse_dist_env = True - self.reuse_dist_env = False - if self.reuse_dist_env: if num_procs not in self._pool_cache: self._pool_cache[num_procs] = mp.Pool(processes=num_procs)