From df1bc6ab18296632bab61a7481767254d3fc511f Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Fri, 31 Jan 2025 10:40:33 -0800
Subject: [PATCH 1/8] Update

[ghstack-poisoned]
---
 torchtitan/config_manager.py        | 16 ++++++
 torchtitan/model_spec.py            | 27 +++++++++
 torchtitan/models/__init__.py       | 87 ++++++++++++++++++++++++++---
 torchtitan/models/llama/__init__.py | 17 +++++-
 torchtitan/models/llama/model.py    |  5 +-
 train.py                            | 21 +++----
 6 files changed, 150 insertions(+), 23 deletions(-)
 create mode 100644 torchtitan/model_spec.py

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index d59e34bc..f5c01f78 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import importlib
 import sys
 from collections import defaultdict
 from typing import Tuple, Union
@@ -375,6 +376,12 @@ def __init__(self):
             The default value is 'allgather'.
             """,
         )
+        self.parser.add_argument(
+            "--experimental.model_module_path",
+            type=str,
+            default="",
+            help="",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
@@ -638,6 +645,15 @@ def parse_args(self, args_list: list = sys.argv[1:]):
                 exp["pipeline_parallel_split_points"]
             )
 
+        if (
+            "experimental" in args_dict
+            and "model_module_path" in args_dict["experimental"]
+            and args_dict["experimental"]["model_module_path"]
+        ):
+            from torchtitan.models import add_model_spec_path
+
+            add_model_spec_path(args_dict["experimental"]["model_module_path"])
+
         # override args dict with cmd_args
         cmd_args_dict = self._args_to_two_level_dict(cmd_args)
         for section, section_args in cmd_args_dict.items():
diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py
new file mode 100644
index 00000000..74cc69c2
--- /dev/null
+++ b/torchtitan/model_spec.py
@@ -0,0 +1,27 @@
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Protocol, Tuple, Type
+
+import torch.nn as nn
+from torch.distributed.pipelining.schedules import _PipelineSchedule
+
+@dataclass
+class BaseModelArgs:
+    _enforced: str = "This field is used to enforce all fields have defaults."
+
+
+class ModelProtocol(Protocol):
+    def from_model_args(self, args: BaseModelArgs) -> nn.Module:
+        ...
+
+
+@dataclass
+class ModelSpec:
+    name: str
+    cls: Type[nn.Module]
+    config: Dict[str, BaseModelArgs]
+    # As for now, this is a string. So it will have to be built-in to the
+    # TorchTitan library. In the future, we can make this a defined class
+    # that can be extended like ModelSpec.
+    tokenizer: str
+    parallelize_fn: Callable[[nn.Module], None]
+    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py
index c666b065..ee9bf801 100644
--- a/torchtitan/models/__init__.py
+++ b/torchtitan/models/__init__.py
@@ -4,14 +4,85 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.models.llama import llama3_configs, Transformer
+import importlib.util
 
-models_config = {
-    "llama3": llama3_configs,
-}
+import os
+import pkgutil
+from typing import Dict, Set
 
-model_name_to_cls = {"llama3": Transformer}
+import torchtitan.models as models
+from torchtitan.model_spec import ModelSpec
 
-model_name_to_tokenizer = {
-    "llama3": "tiktoken",
-}
+
+_model_specs_path: Set[str] = set()
+
+
+def _load_module(path: str):
+    path = os.path.expanduser(path)
+
+    # 1. Check if path is an existing file or directory path.
+    if os.path.exists(path):
+        if os.path.isdir(path):
+            init_file = os.path.join(path, "__init__.py")
+            if os.path.isfile(init_file):
+                return _load_module_from_init(path)
+
+            raise ImportError(
+                f"Directory '{path}' is not a Python package because it does not "
+                "contain an __init__.py file."
+            )
+        else:
+            raise ImportError(f"Path '{path}' is not a directory.")
+
+    # 2. If not a valid path, assume it's a dotted module name.
+    return importlib.import_module(path)
+
+
+def _load_module_from_init(path: str):
+    module_name = os.path.basename(os.path.normpath(path))
+    init_file = os.path.join(path, "__init__.py")
+
+    spec = importlib.util.spec_from_file_location(module_name, init_file)
+    if spec is None:
+        raise ImportError(f"Could not create spec from '{init_file}'")
+
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+for _, name, _ in pkgutil.iter_modules(models.__path__):
+    full_module_name = f"{models.__name__}.{name}"
+    _model_specs_path.add(full_module_name)
+    # model_module = importlib.import_module(full_module_name)
+    # load_spec_from_module(model_module)
+
+
+def add_model_spec_path(path: str):
+    global _model_specs_path
+    _model_specs_path.add(path)
+
+
+def build_model_specs() -> Dict[str, ModelSpec]:
+    """
+    Load all model specs from the `models` package.
+    """
+    global _model_specs_path
+    model_specs = {}
+    for path in _model_specs_path:
+        module = _load_module(path)
+        model_spec = getattr(module, "model_spec", None)
+        if model_spec is not None:
+            model_specs[model_spec.name] = model_spec
+        # We would like to just use `model_spec` but current torchtitan parallelize
+        # functions depend on ModelArgs and can cause circular imports.
+        # As a result, we have to use `build_model_spec` as a workaround.
+        build_model_spec = getattr(module, "build_model_spec", None)
+        if build_model_spec:
+            model_spec = build_model_spec()
+            model_specs[model_spec.name] = model_spec
+
+    return model_specs
+
+
+__all__ = ["add_model_spec_path", "build_model_specs"]
diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py
index 3bb430d2..e61538b9 100644
--- a/torchtitan/models/llama/__init__.py
+++ b/torchtitan/models/llama/__init__.py
@@ -6,9 +6,9 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
+from torchtitan.model_spec import ModelSpec
 from torchtitan.models.llama.model import ModelArgs, Transformer
 
-__all__ = ["Transformer"]
 
 llama3_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +40,18 @@
         rope_theta=500000,
     ),
 }
+
+
+def build_model_spec() -> ModelSpec:
+    # Avoid circular import
+    from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+    from torchtitan.parallelisms.pipeline_llama import pipeline_llama
+
+    return ModelSpec(
+        name="llama3",
+        cls=Transformer,
+        config=llama3_configs,
+        tokenizer="tiktoken",
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+    )
diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 641ef6de..d60447e4 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -13,11 +13,12 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torchtitan.model_spec import BaseModelArgs, ModelProtocol
 from torchtitan.models.norms import build_norm
 
 
 @dataclass
-class ModelArgs:
+class ModelArgs(BaseModelArgs):
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
@@ -258,7 +259,7 @@ def init_weights(self, init_std: float):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
-class TransformerBlock(nn.Module):
+class TransformerBlock(nn.Module, ModelProtocol):
     """
     TransformerBlock Module
 
diff --git a/train.py b/train.py
index bac22772..7f7cf2e1 100644
--- a/train.py
+++ b/train.py
@@ -19,13 +19,9 @@
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
 from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
-from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtitan.models import build_model_specs
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
-from torchtitan.parallelisms import (
-    models_parallelize_fns,
-    models_pipelining_fns,
-    ParallelDims,
-)
+from torchtitan.parallelisms import ParallelDims
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 from torchtitan.utils import device_module, device_type
 
@@ -80,9 +76,10 @@ def main(job_config: JobConfig):
         world_mesh, device, job_config.training.seed, job_config.training.deterministic
     )
     model_name = job_config.model.name
+    model_spec = build_model_specs()[model_name]
 
     # build tokenizer
-    tokenizer_type = model_name_to_tokenizer[model_name]
+    tokenizer_type = model_spec.tokenizer
     tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
     # build dataloader
     data_loader = build_hf_data_loader(
@@ -96,8 +93,8 @@ def main(job_config: JobConfig):
     )
 
     # build model (using meta init)
-    model_cls = model_name_to_cls[model_name]
-    model_config = models_config[model_name][job_config.model.flavor]
+    model_cls = model_spec.cls
+    model_config = model_spec.config[job_config.model.flavor]
     # set the model configs from training inputs:
     # 1. norm type to decide which norm layer to use
     # 2. vocab size from tokenizer
@@ -151,7 +148,7 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        pp_schedule, model_parts = models_pipelining_fns[model_name](
+        pp_schedule, model_parts = model_spec.pipelining_fn(
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
         # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -162,14 +159,14 @@ def loss_fn(pred, labels):
         # optimizer, and checkpointing
         for m in model_parts:
             # apply SPMD-style PT-D techniques
-            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
+            model_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
             m.to_empty(device=init_device)
             with torch.no_grad():
                 m.init_weights(buffer_device=buffer_device)
             m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
-        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+        model_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
         model.to_empty(device=init_device)
         with torch.no_grad():
             model.init_weights(buffer_device=buffer_device)

From dfc1649a9d79244ceb83a7923c05717f42a6231b Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Fri, 31 Jan 2025 10:46:46 -0800
Subject: [PATCH 2/8] Update

[ghstack-poisoned]
---
 torchtitan/config_manager.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index f5c01f78..9741e42a 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
-import importlib
 import sys
 from collections import defaultdict
 from typing import Tuple, Union
@@ -377,10 +376,18 @@ def __init__(self):
             """,
         )
         self.parser.add_argument(
-            "--experimental.model_module_path",
+            "--experimental.custom_model_path",
            type=str,
             default="",
-            help="",
+            help="""
+                The --custom_model_path option allows one to specify a custom path to a model module
+
+                that is not natively implemented within TorchTitan.
+
+                Acceptable values are the file system path to the module (e.g., my_models/model_x)
+
+                or the dotted import module (e.g., some_package.model_x).
+            """
         )
         self.parser.add_argument(
             "--training.mixed_precision_param",

From 720f12a51bfa96727a8b118cf6228a529cc7daa4 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Fri, 31 Jan 2025 10:49:46 -0800
Subject: [PATCH 3/8] Update

[ghstack-poisoned]
---
 torchtitan/config_manager.py | 2 +-
 torchtitan/model_spec.py     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 9741e42a..7b8684e8 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -387,7 +387,7 @@ def __init__(self):
                 Acceptable values are the file system path to the module (e.g., my_models/model_x)
 
                 or the dotted import module (e.g., some_package.model_x).
-            """
+            """,
         )
         self.parser.add_argument(
             "--training.mixed_precision_param",
diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py
index 74cc69c2..08efc7be 100644
--- a/torchtitan/model_spec.py
+++ b/torchtitan/model_spec.py
@@ -4,14 +4,14 @@
 import torch.nn as nn
 from torch.distributed.pipelining.schedules import _PipelineSchedule
 
+
 @dataclass
 class BaseModelArgs:
     _enforced: str = "This field is used to enforce all fields have defaults."
 
 
 class ModelProtocol(Protocol):
-    def from_model_args(self, args: BaseModelArgs) -> nn.Module:
-        ...
+    def from_model_args(self, args: BaseModelArgs) -> nn.Module: ...
 
 
 @dataclass

From 225bfcc371b64b56794d43fa801f7c71924452a3 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Fri, 31 Jan 2025 10:58:16 -0800
Subject: [PATCH 4/8] Update

[ghstack-poisoned]
---
 torchtitan/model_spec.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py
index 08efc7be..5ac9ea8b 100644
--- a/torchtitan/model_spec.py
+++ b/torchtitan/model_spec.py
@@ -11,7 +11,8 @@
 
 
 class ModelProtocol(Protocol):
-    def from_model_args(self, args: BaseModelArgs) -> nn.Module: ...
+    def from_model_args(self, args: BaseModelArgs) -> nn.Module:
+        ...
 
 
 @dataclass

From 650152e7b92173c2f43dc323cc530613ad6b8b64 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Fri, 31 Jan 2025 11:00:17 -0800
Subject: [PATCH 5/8] Update

[ghstack-poisoned]
---
 torchtitan/model_spec.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/torchtitan/model_spec.py b/torchtitan/model_spec.py
index 5ac9ea8b..28c050b4 100644
--- a/torchtitan/model_spec.py
+++ b/torchtitan/model_spec.py
@@ -1,3 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+
 from dataclasses import dataclass
 from typing import Callable, Dict, List, Protocol, Tuple, Type
 

From 6a51325f359b6a0a88583aff8ec9f4ff67d66e6e Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 6 Feb 2025 15:27:32 -0800
Subject: [PATCH 6/8] Update

[ghstack-poisoned]
---
 torchtitan/models/llama/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index d60447e4..67c54260 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -259,7 +259,7 @@ def init_weights(self, init_std: float):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
-class TransformerBlock(nn.Module, ModelProtocol):
+class TransformerBlock(nn.Module):
     """
     TransformerBlock Module
 
@@ -332,7 +332,7 @@ def init_weights(self):
         self.feed_forward.init_weights(self.weight_init_std)
 
 
-class Transformer(nn.Module):
+class Transformer(nn.Module, ModelProtocol):
     """
     Transformer Module
 

From 2e569d7d15811747d48d2e981b64eedc3702ddea Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 6 Feb 2025 15:41:35 -0800
Subject: [PATCH 7/8] Update

[ghstack-poisoned]
---
 torchtitan/config_manager.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index e4c58d03..0cda86b4 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -405,11 +405,8 @@ def __init__(self):
             default="",
             help="""
                 The --custom_model_path option allows one to specify a custom path to a model module
-
                 that is not natively implemented within TorchTitan.
-
                 Acceptable values are the file system path to the module (e.g., my_models/model_x)
-
                 or the dotted import module (e.g., some_package.model_x).
""", ) From bab9bf5b85623ab0636e2fed1be12b5517199511 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 6 Feb 2025 16:11:12 -0800 Subject: [PATCH 8/8] Update [ghstack-poisoned] --- torchtitan/__init__.py | 12 ++++++++++++ torchtitan/optimizer.py | 1 - torchtitan/utils.py | 19 +++++++++---------- 3 files changed, 21 insertions(+), 11 deletions(-) create mode 100644 torchtitan/__init__.py diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py new file mode 100644 index 00000000..d39e084f --- /dev/null +++ b/torchtitan/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +# Import the built-in models here so that the corresponding register_model_spec() +# will be called. +import torchtitan.models + diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 631cbea5..9c8d749c 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import functools -from abc import ABC from typing import Any, Callable, Dict, Iterable, List import torch diff --git a/torchtitan/utils.py b/torchtitan/utils.py index 8e43618e..8ff0cd2d 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -439,17 +439,16 @@ def import_module_from_path(path: str): # 1. Check if path is an existing file or directory path. if os.path.exists(path): - if os.path.isdir(path): - init_file = os.path.join(path, "__init__.py") - if os.path.isfile(init_file): - return _import_module_from_init(path) - - raise ImportError( - f"Directory '{path}' is not a Python package because it does not " - "contain an __init__.py file." - ) - else: + if not os.path.isdir(path): raise ImportError(f"Path '{path}' is not a directory.") + init_file = os.path.join(path, "__init__.py") + if os.path.isfile(init_file): + return _import_module_from_init(path) + + raise ImportError( + f"Directory '{path}' is not a Python package because it does not " + "contain an __init__.py file." + ) # 2. If not a valid path, assume it's a dotted module name. return importlib.import_module(path)