From 5f17dcee162a06d821bd96681332cf689e7e6be4 Mon Sep 17 00:00:00 2001
From: Kim Ngo <6362111+findkim@users.noreply.github.com>
Date: Mon, 31 Jul 2023 19:13:04 -0500
Subject: [PATCH] Fix rank where torch.distributed may not be initialized yet
 and would not wait for tokenizer file caching (#7061)

Signed-off-by: Kim Ngo <6362111+findkim@users.noreply.github.com>
Co-authored-by: David
---
 .../collections/nlp/modules/common/megatron/megatron_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_utils.py b/nemo/collections/nlp/modules/common/megatron/megatron_utils.py
index 68437921f930..d610f5b61c24 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_utils.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_utils.py
@@ -21,7 +21,7 @@
 import wget
 from torch.hub import _get_torch_home
 
-from nemo.utils import logging
+from nemo.utils import get_rank, logging
 
 __all__ = [
     "get_megatron_lm_model",
@@ -203,7 +203,7 @@ def _download(path: str, url: str):
     if url is None:
         return None
 
-    if (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0) and not os.path.exists(path):
+    if get_rank.is_global_rank_zero() and not os.path.exists(path):
         os.makedirs(MEGATRON_CACHE, exist_ok=True)
         logging.info(f"Downloading from {url} to {path}")
         downloaded_path = wget.download(url)
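
Reviewer note (not part of the commit message or diff): below is a minimal
sketch of the failure mode this patch addresses, assuming a torchrun- or
SLURM-style launcher that exports a RANK (or SLURM_PROCID) environment
variable for every worker. The function names old_guard and new_guard are
hypothetical, introduced only for illustration.

import torch
import torch.distributed


def old_guard() -> bool:
    # Before init_process_group() runs, is_initialized() is False, so the
    # left side of the `or` makes this True on EVERY worker: all ranks try
    # to download the tokenizer files at once, and non-zero ranks never
    # wait for rank 0 to finish populating the cache (the behavior the
    # commit subject calls out).
    return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0


def new_guard() -> bool:
    # nemo.utils.get_rank.is_global_rank_zero() consults launcher-provided
    # environment variables rather than the (possibly uninitialized)
    # process group, so only the true global rank 0 downloads and the
    # remaining ranks can wait for the cached file to appear.
    from nemo.utils import get_rank

    return get_rank.is_global_rank_zero()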