Add cli_factory decorator
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
akoumpa committed Oct 15, 2024
1 parent eac0634 commit dac970e
Showing 25 changed files with 105 additions and 62 deletions.
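This commit switches every recipe from decorating its factories with `run.cli.factory(...)` directly to a `cli_factory` decorator imported from `nemo.lightning.run`. That module is not among the files shown below, so the following is only a minimal sketch of what such a wrapper could look like, assuming it simply forwards its arguments to NeMo-Run's `run.cli.factory`:

```python
# Hypothetical sketch: nemo/lightning/run.py itself is not visible in this diff.
# Assumption: cli_factory is a thin pass-through to NeMo-Run's run.cli.factory.
import nemo_run as run


def cli_factory(*args, **kwargs):
    """Forward arguments such as name=... and target=... to run.cli.factory."""
    return run.cli.factory(*args, **kwargs)
```

Funneling every recipe through one wrapper means later adjustments to factory registration can land in a single place instead of in each recipe module.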
10 changes: 7 additions & 3 deletions nemo/collections/llm/recipes/ADD-RECIPE.md
@@ -11,10 +11,12 @@ Create a new Python file in the `nemo/collections/llm/recipes/` directory. Name
Create a function called `model` to define the model configuration:

```python
+from nemo.lightning.run import cli_factory

NAME = "my_new_model_12b"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
return run.Config(YourModel, config=run.Config(YourModelConfig))
```
@@ -49,8 +51,9 @@ Create a function called `pretrain_recipe` or `finetune_recipe` to define the recipe:

```python
from nemo.collections.llm import pretrain
+from nemo.lightning.run import cli_factory

-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
# Add other parameters as needed
) -> run.Config[nl.PretrainRecipe]:
@@ -68,8 +71,9 @@ def pretrain_recipe(

```python
from nemo.collections.llm import finetune
+from nemo.lightning.run import cli_factory

-@run.cli.factory(target=finetune, name=NAME)
+@cli_factory(target=finetune, name=NAME)
def finetune_recipe(
# Add other parameters as needed
) -> run.Config[nl.FinetuneRecipe]:
5 changes: 3 additions & 2 deletions nemo/collections/llm/recipes/llama31_405b.py
@@ -27,12 +27,13 @@
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
+from nemo.lightning.run import cli_factory
from nemo.utils.exp_manager import TimingCallback

NAME = "llama31_405b"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Llama3.1 405B model configuration.
@@ -129,7 +130,7 @@ def trainer(
return trainer


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
9 changes: 5 additions & 4 deletions nemo/collections/llm/recipes/llama3_70b.py
@@ -33,12 +33,13 @@
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
+from nemo.lightning.run import cli_factory
from nemo.utils.exp_manager import TimingCallback

NAME = "llama3_70b"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Llama3 70B model configuration.
@@ -140,7 +141,7 @@ def trainer(
return trainer


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
@@ -187,7 +188,7 @@ def pretrain_recipe(
)


-@run.cli.factory(target=pretrain, name=NAME + "_performance")
+@cli_factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
@@ -234,7 +235,7 @@ def pretrain_recipe_performance(
return recipe


-@run.cli.factory(target=finetune, name=NAME)
+@cli_factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
5 changes: 3 additions & 2 deletions nemo/collections/llm/recipes/llama3_70b_16k.py
@@ -23,11 +23,12 @@
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.recipes import llama3_70b
+from nemo.lightning.run import cli_factory

NAME = "llama3_70b_16k"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Llama3 70B model configuration with 16k sequence length.
@@ -87,7 +88,7 @@ def trainer(
)


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None,
name: str = "default",
5 changes: 3 additions & 2 deletions nemo/collections/llm/recipes/llama3_70b_64k.py
@@ -23,12 +23,13 @@
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.recipes import llama3_70b
+from nemo.lightning.run import cli_factory
from nemo.utils.exp_manager import TimingCallback

NAME = "llama3_70b_64k"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Llama3 70B model configuration with 64k sequence length.
@@ -90,7 +91,7 @@ def trainer(
)


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None,
name: str = "default",
9 changes: 5 additions & 4 deletions nemo/collections/llm/recipes/llama3_8b.py
@@ -32,12 +32,13 @@
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
+from nemo.lightning.run import cli_factory
from nemo.utils.exp_manager import TimingCallback

NAME = "llama3_8b"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Llama3 8B model configuration.
@@ -140,7 +141,7 @@ def trainer(
return trainer


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
@@ -188,7 +189,7 @@ def pretrain_recipe(
)


-@run.cli.factory(target=pretrain, name=NAME + "_optimized")
+@cli_factory(target=pretrain, name=NAME + "_optimized")
def pretrain_recipe_performance(
dir: Optional[str] = None,
name: str = "default",
@@ -234,7 +235,7 @@ def pretrain_recipe_performance(
return recipe


-@run.cli.factory(target=finetune, name=NAME)
+@cli_factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
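For context, a recipe registered this way is typically composed and launched through NeMo-Run. The sketch below is a usage example rather than part of the commit; it assumes NeMo and NeMo-Run are installed, builds the `llama3_8b` pretraining recipe touched above, and runs it on the local machine:

```python
# Usage sketch only: launch the llama3_8b pretraining recipe locally with NeMo-Run.
import nemo_run as run

from nemo.collections.llm.recipes import llama3_8b

# pretrain_recipe returns a run.Partial assembled by the factory defined in llama3_8b.py.
recipe = llama3_8b.pretrain_recipe(name="llama3_8b_local", num_nodes=1, num_gpus_per_node=8)

if __name__ == "__main__":
    # LocalExecutor runs the recipe in the current environment; swap in a cluster
    # executor for multi-node jobs.
    run.run(recipe, executor=run.LocalExecutor(), name="llama3_8b_pretrain")
```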
5 changes: 3 additions & 2 deletions nemo/collections/llm/recipes/llama3_8b_16k.py
@@ -23,11 +23,12 @@
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.recipes import llama3_8b
+from nemo.lightning.run import cli_factory

NAME = "llama3_8b_16k"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Llama3 8B model configuration with 16k sequence length.
@@ -87,7 +88,7 @@ def trainer(
)


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None,
name: str = "default",
5 changes: 3 additions & 2 deletions nemo/collections/llm/recipes/llama3_8b_64k.py
@@ -23,11 +23,12 @@
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.recipes import llama3_8b
+from nemo.lightning.run import cli_factory

NAME = "llama3_8b_64k"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Llama3 8B model configuration with 64k sequence length.
@@ -87,7 +88,7 @@ def trainer(
)


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None,
name: str = "default",
7 changes: 4 additions & 3 deletions nemo/collections/llm/recipes/mistral.py
@@ -32,11 +32,12 @@
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.utils.exp_manager import TimingCallback
+from nemo.lightning.run import cli_factory

NAME = "mistral"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Mistral 7B model configuration.
@@ -136,7 +137,7 @@ def trainer(
return trainer


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
@@ -186,7 +187,7 @@ def pretrain_recipe(
)


-@run.cli.factory(target=finetune, name=NAME)
+@cli_factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
9 changes: 5 additions & 4 deletions nemo/collections/llm/recipes/mixtral_8x22b.py
@@ -33,11 +33,12 @@
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback
+from nemo.lightning.run import cli_factory

NAME = "mixtral_8x22b"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Mixtral 8x22B model configuration.
@@ -140,7 +141,7 @@ def trainer(
return trainer


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 16, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
@@ -182,7 +183,7 @@ def pretrain_recipe(
)


-@run.cli.factory(target=pretrain, name=NAME + "_performance")
+@cli_factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
@@ -225,7 +226,7 @@ def pretrain_recipe_performance(
return recipe


-@run.cli.factory(target=finetune, name=NAME)
+@cli_factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
9 changes: 5 additions & 4 deletions nemo/collections/llm/recipes/mixtral_8x7b.py
@@ -33,11 +33,12 @@
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback
+from nemo.lightning.run import cli_factory

NAME = "mixtral_8x7b"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Mixtral 8x7B model configuration.
@@ -139,7 +140,7 @@ def trainer(
return trainer


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
@@ -181,7 +182,7 @@ def pretrain_recipe(
)


-@run.cli.factory(target=pretrain, name=NAME + "_performance")
+@cli_factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
@@ -224,7 +225,7 @@ def pretrain_recipe_performance(
return recipe


-@run.cli.factory(target=finetune, name=NAME)
+@cli_factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
5 changes: 3 additions & 2 deletions nemo/collections/llm/recipes/mixtral_8x7b_16k.py
@@ -24,11 +24,12 @@
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.recipes import mixtral_8x7b
from nemo.utils.exp_manager import TimingCallback
+from nemo.lightning.run import cli_factory

NAME = "mixtral_8x7b_16k"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Mixtral 8x7B model configuration with 16k sequence length.
@@ -91,7 +92,7 @@ def trainer(
)


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None,
name: str = "default",
5 changes: 3 additions & 2 deletions nemo/collections/llm/recipes/mixtral_8x7b_64k.py
@@ -24,11 +24,12 @@
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.recipes import mixtral_8x7b
from nemo.utils.exp_manager import TimingCallback
+from nemo.lightning.run import cli_factory

NAME = "mixtral_8x7b_64k"


-@run.cli.factory(name=NAME)
+@cli_factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Mixtral 8x7B model configuration with 64k sequence length.
@@ -91,7 +92,7 @@ def trainer(
)


-@run.cli.factory(target=pretrain, name=NAME)
+@cli_factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None,
name: str = "default",
(The remaining changed files in this commit are not shown.)
