Removed NLPDDPPlugin Import check #4555

Merged (6 commits) on Jul 15, 2022
21 changes: 10 additions & 11 deletions nemo/core/classes/modelPT.py
@@ -37,13 +37,6 @@
 from nemo.utils.debug_hook import register_debug_hooks
 from nemo.utils.get_rank import get_rank, is_global_rank_zero
 
-try:
-    from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin
-
-    HAVE_NLPPLUGIN = True
-except (ImportError, ModuleNotFoundError):
-    HAVE_NLPPLUGIN = False
-
 __all__ = ['ModelPT']


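For context, the deleted block above is the usual optional-dependency guard: core code probed whether the NLP collection was installed before touching NLPDDPPlugin. After this PR, modelPT.py gets what it needs from AppState, which already lives in nemo.utils, so nothing has to be imported from nemo.collections.nlp. A minimal sketch of that access pattern, assuming the import path nemo.utils.app_state and the two attribute names that appear in the second hunk below; the prints are purely illustrative:

from nemo.utils.app_state import AppState

# AppState is a core-side singleton; its parallelism fields start out as None and
# are typically populated by the NLP/Megatron initialization when data or model
# parallelism is actually configured.
app_state = AppState()
print(app_state.data_parallel_size)   # None unless a data-parallel group size was registered
print(app_state.model_parallel_size)  # None unless model parallelism is configured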
@@ -495,10 +488,16 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None):
             optim_config['sched']['t_max_epochs'] = self._trainer.max_epochs
             optim_config['sched']['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches
             optim_config['sched']['t_limit_train_batches'] = self._trainer.limit_train_batches
-            optim_config['sched']['t_num_workers'] = self._trainer.num_devices * self._trainer.num_nodes
-            if HAVE_NLPPLUGIN and isinstance(self._trainer.accelerator.training_type_plugin, NLPDDPPlugin):
-                app = AppState()
-                optim_config['sched']['t_num_workers'] = app.data_parallel_size
+
+            app_state = AppState()
+            if app_state.data_parallel_size is not None:
+                optim_config['sched']['t_num_workers'] = app_state.data_parallel_size
+            elif app_state.model_parallel_size is None:
+                optim_config['sched']['t_num_workers'] = self._trainer.num_devices * self._trainer.num_nodes
+            else:
+                optim_config['sched']['t_num_workers'] = (
+                    self._trainer.num_devices * self._trainer.num_nodes
+                ) / app_state.model_parallel_size
         else:
             optim_config['sched']['max_steps'] = self._trainer.max_steps

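The replacement branch decides the scheduler's t_num_workers from parallelism state instead of from the training-type plugin's class. A compact sketch of the same decision tree, using a hypothetical helper compute_t_num_workers whose plain arguments stand in for self._trainer.num_devices * self._trainer.num_nodes and the two AppState fields:

from typing import Optional

def compute_t_num_workers(
    world_size: int,                     # stand-in for num_devices * num_nodes
    data_parallel_size: Optional[int],   # stand-in for AppState.data_parallel_size
    model_parallel_size: Optional[int],  # stand-in for AppState.model_parallel_size
) -> float:
    if data_parallel_size is not None:
        # The data-parallel group size is known explicitly; use it directly.
        return data_parallel_size
    if model_parallel_size is None:
        # No model parallelism: every device holds a full model replica.
        return world_size
    # Model parallelism without an explicit data-parallel size:
    # number of replicas = total devices / devices per replica.
    return world_size / model_parallel_size

# Worked example: 2 nodes x 8 GPUs with model-parallel size 4
# -> 16 / 4 = 4 data-parallel replicas driving the schedule.
assert compute_t_num_workers(16, None, 4) == 4
assert compute_t_num_workers(16, 8, 4) == 8     # an explicit data-parallel size wins
assert compute_t_num_workers(16, None, None) == 16

Downstream, t_num_workers feeds the step-count derivation for epoch-based scheduler limits (t_max_epochs, t_limit_train_batches), which is why counting replicas correctly matters here.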