Add Dynamic Model Import and ModelSpec Definition #814
base: gh/fegin/8/base
**torchtitan/config_manager.py**

```diff
@@ -375,6 +375,20 @@ def __init__(self):
             The default value is 'allgather'.
             """,
         )
+        self.parser.add_argument(
+            "--experimental.custom_model_path",
+            type=str,
+            default="",
+            help="""
+            The --experimental.custom_model_path option lets you specify a custom
+            path to a model module that is not natively implemented within
+            TorchTitan. Acceptable values are a file system path to the module
+            (e.g., my_models/model_x) or a dotted import path (e.g.,
+            some_package.model_x).
+            """,
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
```
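For reference, the new flag would be passed like any other TorchTitan option, e.g. `python train.py --experimental.custom_model_path my_models/model_x` or `--experimental.custom_model_path some_package.model_x` (the `train.py` entry point here is an assumption for illustration, not taken from this diff).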
```diff
@@ -638,6 +652,15 @@ def parse_args(self, args_list: list = sys.argv[1:]):
                 exp["pipeline_parallel_split_points"]
             )
 
+        if (
+            "experimental" in args_dict
+            and "custom_model_path" in args_dict["experimental"]
+            and args_dict["experimental"]["custom_model_path"]
+        ):
+            from torchtitan.models import add_model_spec_path
+
+            add_model_spec_path(args_dict["experimental"]["custom_model_path"])
+
         # override args dict with cmd_args
         cmd_args_dict = self._args_to_two_level_dict(cmd_args)
         for section, section_args in cmd_args_dict.items():
```

> **Review comment:** may I ask why putting it here, instead of in …?
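The diff wires registration into `parse_args`; the same thing can be done programmatically with the API this PR introduces. A minimal sketch (the registered path is an illustrative assumption):

```python
from torchtitan.models import add_model_spec_path, build_model_specs

# Register a custom model module by file-system path or dotted name ...
add_model_spec_path("my_models/model_x")

# ... then materialize every known spec, built-in and custom.
specs = build_model_specs()
print(sorted(specs))  # e.g. ["llama3", "model_x"]
```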
**torchtitan/model_spec.py** (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule


@dataclass
class BaseModelArgs:
    _enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
    def from_model_args(self, args: BaseModelArgs) -> nn.Module:
        ...


@dataclass
class ModelSpec:
    name: str
    cls: Type[nn.Module]
    config: Dict[str, BaseModelArgs]
    # As of now, this is a string, so the tokenizer has to be built into the
    # TorchTitan library. In the future, we can make this a defined class
    # that can be extended like ModelSpec.
    tokenizer: str
    parallelize_fn: Callable[[nn.Module], None]
    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
```
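To make the contract concrete, here is a minimal sketch of a model satisfying `ModelProtocol`; the names `MyModelArgs` and `MyModel` are hypothetical. It assumes the classmethod shape that llama's `Transformer` uses for `from_model_args`:

```python
from dataclasses import dataclass

import torch.nn as nn

from torchtitan.model_spec import BaseModelArgs, ModelProtocol


@dataclass
class MyModelArgs(BaseModelArgs):
    dim: int = 256  # every field needs a default; that is what _enforced enforces


class MyModel(nn.Module):
    """A toy model conforming to ModelProtocol."""

    def __init__(self, args: MyModelArgs):
        super().__init__()
        self.proj = nn.Linear(args.dim, args.dim)

    @classmethod
    def from_model_args(cls, args: MyModelArgs) -> "MyModel":
        return cls(args)
```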
**torchtitan/models/__init__.py**

```diff
@@ -4,14 +4,85 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.models.llama import llama3_configs, Transformer
+import importlib
+import importlib.util  # explicit: not guaranteed to be pulled in by `import importlib`
+import os
+import pkgutil
+from typing import Dict, Set
 
-models_config = {
-    "llama3": llama3_configs,
-}
 
-model_name_to_cls = {"llama3": Transformer}
+import torchtitan.models as models
+from torchtitan.model_spec import ModelSpec
 
-model_name_to_tokenizer = {
-    "llama3": "tiktoken",
-}
 
+_model_specs_path: Set[str] = set()
```
```diff
+def _load_module(path: str):
+    path = os.path.expanduser(path)
+
+    # 1. Check if path is an existing file or directory path.
+    if os.path.exists(path):
+        if os.path.isdir(path):
+            init_file = os.path.join(path, "__init__.py")
+            if os.path.isfile(init_file):
+                return _load_module_from_init(path)
+
+            raise ImportError(
+                f"Directory '{path}' is not a Python package because it does not "
+                "contain an __init__.py file."
+            )
+        else:
+            raise ImportError(f"Path '{path}' is not a directory.")
+
+    # 2. If not a valid path, assume it's a dotted module name.
+    return importlib.import_module(path)
```

> **Review comment:** sorry if this is a noob question: don't we need to put this function before …?

> **Review comment (on lines +24 to +28):** maybe let's do the assert first to avoid a nested if for better readability?
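A behavior-preserving sketch of the guard-first shape the reviewer suggests (keeping the early raises as `ImportError` rather than literal asserts, to match the original error type):

```python
def _load_module(path: str):
    path = os.path.expanduser(path)

    # Not an existing path: assume a dotted module name.
    if not os.path.exists(path):
        return importlib.import_module(path)

    if not os.path.isdir(path):
        raise ImportError(f"Path '{path}' is not a directory.")

    init_file = os.path.join(path, "__init__.py")
    if not os.path.isfile(init_file):
        raise ImportError(
            f"Directory '{path}' is not a Python package because it does not "
            "contain an __init__.py file."
        )

    return _load_module_from_init(path)
```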
```diff
+def _load_module_from_init(path: str):
+    module_name = os.path.basename(os.path.normpath(path))
+    init_file = os.path.join(path, "__init__.py")
+
+    spec = importlib.util.spec_from_file_location(module_name, init_file)
+    if spec is None:
+        raise ImportError(f"Could not create spec from '{init_file}'")
+
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
```

> **Review comment (on lines +46 to +47):** ditto: directly assert?
```diff
+for _, name, _ in pkgutil.iter_modules(models.__path__):
+    full_module_name = f"{models.__name__}.{name}"
+    _model_specs_path.add(full_module_name)
+    # model_module = importlib.import_module(full_module_name)
+    # load_spec_from_module(model_module)
+
+
+def add_model_spec_path(path: str):
+    global _model_specs_path
+    _model_specs_path.add(path)
```

> **Review comment:** why do we have a global for loop here?

> **Reply:** I think this is for importing the models in …
```diff
+def build_model_specs() -> Dict[str, ModelSpec]:
+    """
+    Load all model specs from the `models` package.
+    """
+    global _model_specs_path
+    model_specs = {}
+    for path in _model_specs_path:
+        module = _load_module(path)
+        model_spec = getattr(module, "model_spec", None)
+        if model_spec is not None:
+            model_specs[model_spec.name] = model_spec
+        # We would like to just use `model_spec`, but the current torchtitan
+        # parallelize functions depend on ModelArgs and can cause circular
+        # imports. As a result, we have to use `build_model_spec` as a workaround.
+        build_model_spec = getattr(module, "build_model_spec", None)
+        if build_model_spec:
+            model_spec = build_model_spec()
+            model_specs[model_spec.name] = model_spec
+
+    return model_specs
+
+
+__all__ = ["add_model_spec_path", "build_model_specs"]
```
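Since `build_model_specs` accepts either export style, a custom module registered via `--experimental.custom_model_path` only needs to expose a `model_spec` attribute or a `build_model_spec()` factory. A minimal sketch of a hypothetical `my_models/model_x/__init__.py` (all `model_x` names are assumptions):

```python
from torchtitan.model_spec import ModelSpec

# Option 1: export a ready-made spec as a module attribute named `model_spec`.
# model_spec = ModelSpec(name="model_x", ...)


# Option 2: export a `build_model_spec()` factory, useful when constructing
# the spec eagerly would trigger circular imports.
def build_model_spec() -> ModelSpec:
    from my_models.model_x.model import ModelX, ModelXArgs
    from my_models.model_x.parallelize import parallelize_model_x, pipeline_model_x

    return ModelSpec(
        name="model_x",
        cls=ModelX,
        config={"base": ModelXArgs()},
        tokenizer="tiktoken",
        parallelize_fn=parallelize_model_x,
        pipelining_fn=pipeline_model_x,
    )
```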
**torchtitan/models/llama/__init__.py**

```diff
@@ -6,9 +6,9 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
+from torchtitan.model_spec import ModelSpec
 from torchtitan.models.llama.model import ModelArgs, Transformer
 
 __all__ = ["Transformer"]
 
 llama3_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
```
```diff
@@ -40,3 +40,18 @@
         rope_theta=500000,
     ),
 }
+
+
+def build_model_spec() -> ModelSpec:
+    # Avoid circular import
+    from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+    from torchtitan.parallelisms.pipeline_llama import pipeline_llama
+
+    return ModelSpec(
+        name="llama3",
+        cls=Transformer,
+        config=llama3_configs,
+        tokenizer="tiktoken",
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+    )
```

> **Review comment:** I suggest we do a restructuring of the repo to make it more logical. E.g. … can be renamed … You'd avoid the circular import if we put …

> **Review comment:** Please add unit tests, and documentation with examples in …
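For orientation, a hedged sketch of how a caller might consume the resulting spec; the trainer-side wiring is not part of this diff, and the `from_model_args` call assumes llama's `Transformer` classmethod:

```python
from torchtitan.models import build_model_specs

specs = build_model_specs()
spec = specs["llama3"]

# Build the model from one of its registered configs.
model = spec.cls.from_model_args(spec.config["debugmodel"])

# spec.parallelize_fn / spec.pipelining_fn would then be applied; their real
# signatures take more arguments than the simplified Callable types suggest.
```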
**torchtitan/models/llama/model.py**

```diff
@@ -13,11 +13,12 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torchtitan.model_spec import BaseModelArgs, ModelProtocol
 from torchtitan.models.norms import build_norm
 
 
 @dataclass
-class ModelArgs:
+class ModelArgs(BaseModelArgs):
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
```

> **Review comment:** Down the road we will have many models, like MM models. Do we want all model args to inherit this? Currently we use different model args for different model architectures.
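A side note on the `_enforced` field in `BaseModelArgs`: dataclass inheritance places the defaulted `_enforced` field before any subclass fields, so a subclass field without a default fails at class-definition time. A minimal sketch with a hypothetical `BadArgs`:

```python
from dataclasses import dataclass

from torchtitan.model_spec import BaseModelArgs


@dataclass
class BadArgs(BaseModelArgs):
    dim: int  # TypeError: non-default argument 'dim' follows default argument
```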
```diff
@@ -258,7 +259,7 @@ def init_weights(self, init_std: float):
         nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
-class TransformerBlock(nn.Module):
+class TransformerBlock(nn.Module, ModelProtocol):
     """
     TransformerBlock Module
```

> **Review comment:** should be on …

> **Review comment:** nit: is this expected?