Add Dynamic Model Import and ModelSpec Definition #814

Open · wants to merge 5 commits into base: gh/fegin/8/base
23 changes: 23 additions & 0 deletions torchtitan/config_manager.py
@@ -375,6 +375,20 @@ def __init__(self):
The default value is 'allgather'.
""",
)
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --custom_model_path option allows you to specify a custom path to a model
module that is not natively implemented within TorchTitan. Acceptable values
are a file system path to the module (e.g., my_models/model_x) or a dotted
import module name (e.g., some_package.model_x).
""",
)

Reviewer comment (Contributor): nit: is this expected?
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
@@ -638,6 +652,15 @@ def parse_args(self, args_list: list = sys.argv[1:]):
exp["pipeline_parallel_split_points"]
)

if (
"experimental" in args_dict
and "model_module_path" in args_dict["experimental"]
and args_dict["experimental"]["model_module_path"]
):
from torchtitan.models import add_model_spec_path

add_model_spec_path(args_dict["experimental"]["model_module_path"])
Reviewer comment (Contributor): May I ask why this is done here instead of in build_model_specs? Is it because you think it is better to fail early? In general, torchtitan tries to put a failure check close to where the value is used, so that the checking and the functionality are less scattered.

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
for section, section_args in cmd_args_dict.items():
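As a usage note, here is a minimal programmatic sketch of exercising the new option, assuming the flag name lands as written in this diff; the path my_models/model_x is a placeholder. The option should also be settable under the [experimental] section of a job TOML file.

    from torchtitan.config_manager import JobConfig

    # Parse a command line that points TorchTitan at an out-of-tree model module.
    # Note: the parse_args hook above keys on "model_module_path", while the flag
    # is registered as "custom_model_path", so whether this actually triggers
    # add_model_spec_path depends on the final key name.
    config = JobConfig()
    config.parse_args(["--experimental.custom_model_path", "my_models/model_x"])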
37 changes: 37 additions & 0 deletions torchtitan/model_spec.py
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule


@dataclass
class BaseModelArgs:
_enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
def from_model_args(self, args: BaseModelArgs) -> nn.Module:
...


@dataclass
class ModelSpec:
name: str
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
# For now, this is a string, so the tokenizer has to be one that is built into
# the TorchTitan library. In the future, we can make this a defined class
# that can be extended like ModelSpec.
tokenizer: str
parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
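For orientation, here is a hypothetical sketch of what an out-of-tree model package could export so that the loader added below in torchtitan/models/__init__.py can pick it up. Every name in it (my_models/model_x, MyTransformer, my_configs, parallelize_model_x, pipeline_model_x) is a placeholder, not part of this PR:

    # my_models/model_x/__init__.py  (hypothetical external package)
    from torchtitan.model_spec import ModelSpec

    from .model import MyTransformer, my_configs        # placeholder model + flavor dict
    from .parallelize import parallelize_model_x        # placeholder SPMD parallelize fn
    from .pipeline import pipeline_model_x              # placeholder pipelining fn

    # The loader looks for a module-level `model_spec` (or a `build_model_spec()`
    # factory) and registers it under `model_spec.name`.
    model_spec = ModelSpec(
        name="model_x",
        cls=MyTransformer,
        config=my_configs,            # Dict[str, BaseModelArgs], keyed by flavor
        tokenizer="tiktoken",
        parallelize_fn=parallelize_model_x,
        pipelining_fn=pipeline_model_x,
    )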
87 changes: 79 additions & 8 deletions torchtitan/models/__init__.py
@@ -4,14 +4,85 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer
import importlib

models_config = {
"llama3": llama3_configs,
}
import os
import pkgutil
from typing import Dict, Set

model_name_to_cls = {"llama3": Transformer}
import torchtitan.models as models
from torchtitan.model_spec import ModelSpec

model_name_to_tokenizer = {
"llama3": "tiktoken",
}

_model_specs_path: Set[str] = set()
Reviewer comment (Contributor): nit, suggested change: rename _model_specs_path to _model_spec_paths.



def _load_module(path: str):
path = os.path.expanduser(path)

# 1. Check if path is an existing file or directory path.
if os.path.exists(path):
if os.path.isdir(path):
init_file = os.path.join(path, "__init__.py")
if os.path.isfile(init_file):
return _load_module_from_init(path)
Reviewer comment (Contributor): Sorry if this is a noob question, but don't we need to define this function before _load_module?

Reviewer comment (Contributor), on lines +24 to +28: Maybe do the assert first, to avoid a nested if and improve readability?


raise ImportError(
f"Directory '{path}' is not a Python package because it does not "
"contain an __init__.py file."
)
else:
raise ImportError(f"Path '{path}' is not a directory.")

# 2. If not a valid path, assume it's a dotted module name.
return importlib.import_module(path)


def _load_module_from_init(path: str):
module_name = os.path.basename(os.path.normpath(path))
init_file = os.path.join(path, "__init__.py")

spec = importlib.util.spec_from_file_location(module_name, init_file)
if spec is None:
raise ImportError(f"Could not create spec from '{init_file}'")
Reviewer comment (Contributor), on lines +46 to +47: Ditto: directly assert?


module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


for _, name, _ in pkgutil.iter_modules(models.__path__):
full_module_name = f"{models.__name__}.{name}"
_model_specs_path.add(full_module_name)
# model_module = importlib.import_module(full_module_name)
# load_spec_from_module(model_module)

Reviewer comment (Contributor): Why do we have a global for loop here?
Reviewer comment (Contributor): I think this is for importing the models in the models folder, i.e. the "official" models in torchtitan.


def add_model_spec_path(path: str):
global _model_specs_path
_model_specs_path.add(path)


def build_model_specs() -> Dict[str, ModelSpec]:
"""
Load all model specs from the `models` package.
"""
global _model_specs_path
model_specs = {}
for path in _model_specs_path:
module = _load_module(path)
model_spec = getattr(module, "model_spec", None)
if model_spec is not None:
model_specs[model_spec.name] = model_spec
# We would like to just use `model_spec` but current torchtitan parallelize
# functions depend on ModelArgs and can cause circular imports.
# As a result, we have to use `build_model_spec` as a workaround.
build_model_spec = getattr(module, "build_model_spec", None)
if build_model_spec:
model_spec = build_model_spec()
model_specs[model_spec.name] = model_spec

return model_specs


__all__ = ["add_model_spec_path", "build_model_specs"]
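A short usage sketch of the two exported helpers; my_models/model_x is a placeholder, and either form described in the config help above (file system path or dotted module name) should work, since _load_module falls back to importlib.import_module:

    from torchtitan.models import add_model_spec_path, build_model_specs

    # Register an out-of-tree module before the specs are built.
    add_model_spec_path("my_models/model_x")

    # Built-in specs (e.g. llama3) and registered external specs are returned
    # together, keyed by ModelSpec.name.
    specs = build_model_specs()
    llama3_spec = specs["llama3"]
    custom_spec = specs.get("model_x")   # present only if the placeholder module exists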
17 changes: 16 additions & 1 deletion torchtitan/models/llama/__init__.py
@@ -6,9 +6,9 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.model_spec import ModelSpec
from torchtitan.models.llama.model import ModelArgs, Transformer

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +40,18 @@
rope_theta=500000,
),
}


def build_model_spec() -> ModelSpec:
# Avoid circular import
Reviewer comment (Contributor): I suggest restructuring the repo to make it more logical. For example (names can be changed): torchtitan/examples would include llama, llama_multimodal, etc. as folders, and torchtitan/examples/llama would include the model folder and the parallelize_llama folder/file. The original parallelisms folder can stay there, including parallel_dims.py and common utils. You'd avoid the circular import if ModelSpec were put in the llama folder instead of in the model folder here.
Reviewer comment (Contributor): Please add unit tests, and documentation with examples in docs/extension.md.

from torchtitan.parallelisms.parallelize_llama import parallelize_llama
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

return ModelSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
tokenizer="tiktoken",
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
)
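Following up on the unit-test request above, a minimal pytest-style sketch of a first test; the test name and location are placeholders:

    from torchtitan.models import build_model_specs
    from torchtitan.models.llama import build_model_spec, Transformer

    def test_llama3_model_spec_is_discoverable():
        # The factory should produce a fully populated spec ...
        spec = build_model_spec()
        assert spec.name == "llama3"
        assert spec.cls is Transformer
        assert spec.tokenizer == "tiktoken"
        assert "debugmodel" in spec.config
        # ... and the aggregated registry should expose it by name.
        assert "llama3" in build_model_specs()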
5 changes: 3 additions & 2 deletions torchtitan/models/llama/model.py
@@ -13,11 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.model_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs:
class ModelArgs(BaseModelArgs):
Reviewer comment (Contributor): Down the road we will have many models, like the MM model. Do we want all model args to inherit from this? Currently we use different model args for different model architectures.

dim: int = 4096
n_layers: int = 32
n_heads: int = 32
@@ -258,7 +259,7 @@ def init_weights(self, init_std: float):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class TransformerBlock(nn.Module):
class TransformerBlock(nn.Module, ModelProtocol):
Reviewer comment (Contributor): This should be on Transformer, not TransformerBlock.

"""
TransformerBlock Module

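For anyone wiring a new model up to these interfaces, a hypothetical sketch; MyModelArgs and MyModel are placeholders, and from_model_args is written as a classmethod to mirror how Transformer exposes it:

    from dataclasses import dataclass

    import torch.nn as nn

    from torchtitan.model_spec import BaseModelArgs, ModelProtocol

    @dataclass
    class MyModelArgs(BaseModelArgs):
        # Every field needs a default, since BaseModelArgs already defines one.
        dim: int = 512
        n_layers: int = 2

    class MyModel(nn.Module, ModelProtocol):
        def __init__(self, args: MyModelArgs):
            super().__init__()
            self.layers = nn.ModuleList(
                [nn.Linear(args.dim, args.dim) for _ in range(args.n_layers)]
            )

        @classmethod
        def from_model_args(cls, args: MyModelArgs) -> "MyModel":
            return cls(args)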
21 changes: 9 additions & 12 deletions train.py
@@ -19,13 +19,9 @@
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.models import build_model_specs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import (
models_parallelize_fns,
models_pipelining_fns,
ParallelDims,
)
from torchtitan.parallelisms import ParallelDims
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
from torchtitan.utils import device_module, device_type

@@ -80,9 +76,10 @@ def main(job_config: JobConfig):
world_mesh, device, job_config.training.seed, job_config.training.deterministic
)
model_name = job_config.model.name
model_spec = build_model_specs()[model_name]

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer_type = model_spec.tokenizer
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
# build dataloader
data_loader = build_hf_data_loader(
@@ -96,8 +93,8 @@
)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
model_config = models_config[model_name][job_config.model.flavor]
model_cls = model_spec.cls
model_config = model_spec.config[job_config.model.flavor]
# set the model configs from training inputs:
# 1. norm type to decide which norm layer to use
# 2. vocab size from tokenizer
@@ -151,7 +148,7 @@ def loss_fn(pred, labels):
# apply parallelisms and initialization
if parallel_dims.pp_enabled:
# apply PT-D Pipeline Parallel
pp_schedule, model_parts = models_pipelining_fns[model_name](
pp_schedule, model_parts = model_spec.pipelining_fn(
model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)
# when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -162,14 +159,14 @@
# optimizer, and checkpointing
for m in model_parts:
# apply SPMD-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
model_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
m.to_empty(device=init_device)
with torch.no_grad():
m.init_weights(buffer_device=buffer_device)
m.train()
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
model_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
model.to_empty(device=init_device)
with torch.no_grad():
model.init_weights(buffer_device=buffer_device)