Skip to content

Commit

Permalink
Free memory in universal checkpointing tests (#6693)
Browse files Browse the repository at this point in the history
Tests in universal checkpointing were not freeing the engine after use
when `reuse_dist_env` was set to `True`, leading to memory leaks.
This PR ensures the engine is freed in the tests and enables
`reuse_dist_env`.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
tohtana and loadams authored Oct 31, 2024
1 parent ff1c543 commit 95ea95f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 11 deletions.
5 changes: 3 additions & 2 deletions tests/unit/checkpoint/test_universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype,
torch.save((model_state, optimizer_state), os.path.join(tmpdir, "baseline_state.pt"))

dist.barrier()

return model, sd
model.destroy()


@pytest.fixture
Expand Down Expand Up @@ -213,6 +212,8 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
univ_model.backward(loss)
univ_model.step()

univ_model.destroy()

@pytest.mark.world_size(2)
def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
Expand Down
9 changes: 0 additions & 9 deletions tests/unit/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
# Worker timeout for tests that hang
DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))

warn_reuse_dist_env = False


def is_rocm_pytorch():
return hasattr(torch.version, 'hip') and torch.version.hip is not None
Expand Down Expand Up @@ -178,13 +176,6 @@ def _launch_daemonic_procs(self, num_procs, init_method):
print("Ignoring reuse_dist_env for hpu")
self.reuse_dist_env = False

global warn_reuse_dist_env
if self.reuse_dist_env and not warn_reuse_dist_env:
# Currently we see memory leak for tests that reuse distributed environment
print("Ignoring reuse_dist_env and forcibly setting it to False")
warn_reuse_dist_env = True
self.reuse_dist_env = False

if self.reuse_dist_env:
if num_procs not in self._pool_cache:
self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
Expand Down

0 comments on commit 95ea95f

Please sign in to comment.