forked from NVIDIA/NeMo
[NeMo-UX] Add token drop callback and optimize mixtral configs (NVIDIA#10361)

* add token drop plugin
* add checks
* add expert parallel configs
* Apply isort and black reformatting
* amend comment
* Apply isort and black reformatting
* add comm overlap
* fix rebase errors
* Apply isort and black reformatting
* fix typo
* add test configs
* fix

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Signed-off-by: George Armstrong <georgea@nvidia.com>
1 parent: 34041b5
commit: 1035c70
Showing 8 changed files with 233 additions and 43 deletions.
@@ -0,0 +1,55 @@
import pytorch_lightning as pl
from megatron.core import ModelParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy


class MegatronTokenDropCallback(Callback):
    """
    A PyTorch Lightning callback to enable token drop for MOEs. Token drop improves performance by better
    balancing work across experts, but may affect convergence.

    Args:
        moe_expert_capacity_factor (float): The capacity factor for all experts
        moe_pad_expert_input_to_capacity (bool): Pad the input for each expert to the expert capacity length

    Example:
        >>> callback = MegatronTokenDropCallback()
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(
        self,
        moe_expert_capacity_factor: float = 1.0,
        moe_pad_expert_input_to_capacity: bool = True,
    ):

        if moe_expert_capacity_factor < 0:
            moe_expert_capacity_factor = None
        self.moe_expert_capacity_factor = moe_expert_capacity_factor
        self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity

    def _set_cfgs(self, cfg):
        cfg.moe_expert_capacity_factor = self.moe_expert_capacity_factor
        cfg.moe_pad_expert_input_to_capacity = self.moe_pad_expert_input_to_capacity

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
        assert isinstance(trainer.strategy, MegatronStrategy), "MegatronTokenDrop requires MegatronStrategy"
        if hasattr(trainer.model, "config") and isinstance(trainer.model.config, ModelParallelConfig):
            assert trainer.model.config.moe_token_dispatcher_type in [
                "alltoall",
                "alltoall_seq",
            ], 'moe_expert_capacity_factor only works with alltoall token dispatcher'
            assert trainer.model.config.moe_router_load_balancing_type in [
                "aux_loss",
                "none",
            ], 'moe_expert_capacity_factor only works with aux_loss or none load balancing'

            if self.moe_pad_expert_input_to_capacity:
                if self.moe_expert_capacity_factor is None:
                    raise ValueError('moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity')

            self._set_cfgs(trainer.model.config)
            if hasattr(trainer.model, '__io__'):
                self._set_cfgs(trainer.model.__io__.config)
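For context, a minimal usage sketch of the new callback follows. The diff does not show the new file's path, so the import path for MegatronTokenDropCallback, the MegatronStrategy arguments, and the Trainer setup below are assumptions based on typical NeMo 2.0 (NeMo-UX) usage, not part of this commit.

# Usage sketch (not part of the commit). The module path for the callback is
# assumed; strategy and trainer arguments are illustrative only.
import pytorch_lightning as pl

from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback  # assumed path

# Default settings: capacity factor 1.0 and padding expert inputs to capacity,
# which enables token drop for MoE models such as Mixtral.
token_drop = MegatronTokenDropCallback(
    moe_expert_capacity_factor=1.0,
    moe_pad_expert_input_to_capacity=True,
)

trainer = pl.Trainer(
    strategy=MegatronStrategy(),  # setup() asserts the strategy is a MegatronStrategy
    callbacks=[token_drop],
    max_steps=10,
)
# During trainer setup, the callback copies its two settings onto
# trainer.model.config (a Megatron ModelParallelConfig), after checking that the
# model uses an "alltoall"/"alltoall_seq" token dispatcher and "aux_loss" or
# "none" router load balancing.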