Commit f7c669b: Fix for LSP

jaysonfrancis committed Jan 1, 2025 · 1 parent 8bcda62
Showing 2 changed files with 72 additions and 84 deletions.
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
@@ -6,3 +6,4 @@ tiktoken
 blobfile
 tabulate
 wandb
+tyro
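
For context: tyro, the new dependency, generates a command-line interface directly from a typed dataclass, which is what the refactored config_manager.py below relies on. A minimal sketch of the idea (the Training fields here are illustrative, not the repo's actual ones):

    from dataclasses import dataclass
    import tyro

    @dataclass
    class Training:
        batch_size: int = 8
        steps: int = 100

    # tyro derives flags from the fields (e.g. `--batch-size 16 --steps 50`)
    # and returns a populated Training instance.
    config = tyro.cli(Training)
    print(config)
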
155 changes: 71 additions & 84 deletions torchtitan/config_manager.py
@@ -4,9 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass, field, fields, is_dataclass
-from pprint import pformat
-from typing import Any, Dict, List, Optional, Type, Union
+from dataclasses import dataclass, field, fields, is_dataclass, asdict
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import tyro
@@ -409,23 +408,6 @@ class MemoryEstimation:
 
 
 @dataclass
-class _JobConfig:
-    job: Job = field(default_factory=Job)
-    profiling: Profiling = field(default_factory=Profiling)
-    metrics: Metrics = field(default_factory=Metrics)
-    model: Model = field(default_factory=Model)
-    optimizer: Optimizer = field(default_factory=Optimizer)
-    training: Training = field(default_factory=Training)
-    experimental: Experimental = field(default_factory=Experimental)
-    checkpoint: Checkpoint = field(default_factory=Checkpoint)
-    activation_checkpoint: ActivationCheckpoint = field(
-        default_factory=ActivationCheckpoint
-    )
-    float8: Float8 = field(default_factory=Float8)
-    comm: Comm = field(default_factory=Comm)
-    memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
-
-
 class JobConfig:
     """
     A helper class to manage the train configuration.
@@ -448,90 +430,95 @@ class JobConfig:
     in the toml file
     """

-    def __init__(self, config_class=_JobConfig):
-        self.config_class = config_class
-        self.defaults = config_class()
-        self.config = None  # set during parse_args()
-
-    def __getattr__(self, item):
-        if self.config is not None and hasattr(self.config, item):
-            return getattr(self.config, item)
-        raise AttributeError(f"'JobConfig' object has no attribute '{item}'")
-
-    def __repr__(self):
-        return pformat(self.config, sort_dicts=False)
+    job: Job = field(default_factory=Job)
+    profiling: Profiling = field(default_factory=Profiling)
+    metrics: Metrics = field(default_factory=Metrics)
+    model: Model = field(default_factory=Model)
+    optimizer: Optimizer = field(default_factory=Optimizer)
+    training: Training = field(default_factory=Training)
+    experimental: Experimental = field(default_factory=Experimental)
+    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    activation_checkpoint: ActivationCheckpoint = field(
+        default_factory=ActivationCheckpoint
+    )
+    float8: Float8 = field(default_factory=Float8)
+    comm: Comm = field(default_factory=Comm)
+    memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
 
-    def __str__(self):
-        return repr(self)
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
 
-    def parse_args(self):
-        cli_config = tyro.cli(self.config_class)
+    @classmethod
+    def parse_args(cls) -> "JobConfig":
+        """
+        Parse CLI arguments, optionally load from a TOML file,
+        merge with defaults, and return a JobConfig instance.
+        """
+        cli_config = tyro.cli(cls)
         config_file = cli_config.job.config_file
 
         if config_file:
             logger.info(f"Loading configuration from {config_file}")
-            toml_config = self._dict_to_dataclass(
-                self.config_class, self.load_toml(config_file)
-            )
-            base_config = self.merge_with_defaults(toml_config, self.defaults)
+            toml_data = cls._load_toml(config_file)
+            toml_config = cls._dict_to_dataclass(cls, toml_data)
+
+            # TOML > defaults
+            merged_config = cls._merge_with_defaults(cli_config, toml_config)
+
+            # cmdline > TOML > defaults
+            final_config = tyro.cli(cls, default=merged_config)
         else:
-            base_config = self.defaults
+            final_config = cli_config
 
-        self.config = tyro.cli(
-            self.config_class, default=base_config
-        )  # override cli args
-        self._validate_config(self.config)
-        return self.config
+        final_config._validate_config()
+        return final_config

-    def load_toml(self, file_path: str) -> Dict[str, Any]:
-        """Load configuration from a TOML file."""
+    @staticmethod
+    def _load_toml(file_path: str) -> Dict[str, Any]:
         try:
             with open(file_path, "rb") as f:
                 return tomllib.load(f)
         except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
-            logger.exception(f"Error while loading the configuration file: {file_path}")
-            logger.exception(f"Error details: {str(e)}")
+            logger.exception(f"Error while loading config file: {file_path}")
             raise e

-    def _dict_to_dataclass(self, cls: Type, data: Dict[str, Any]):
-        """Convert a dictionary to a nested dataclass instance."""
-        if not is_dataclass(cls):
+    @classmethod
+    def _dict_to_dataclass(cls, config_class, data: Dict[str, Any]) -> Any:
+        """Recursively convert dictionaries to nested dataclasses."""
+        if not is_dataclass(config_class):
             return data
 
-        def convert_value(field_type, value):
-            # Recursively convert nested dict into dataclass as needed
-            if is_dataclass(field_type) and isinstance(value, dict):
-                return self._dict_to_dataclass(field_type, value)
-            return value
-
-        field_values = {
-            f.name: convert_value(f.type, data[f.name])
-            for f in fields(cls)
-            if f.name in data
-        }
-        return cls(**field_values)
-
-    def _merge_with_defaults(self, source: Any, defaults: Any):
-        """Recursively merge two dataclass instances."""
+        kwargs = {}
+        for f in fields(config_class):
+            if f.name in data:
+                value = data[f.name]
+                # If target field is also a dataclass and value is a dict, recurse
+                if is_dataclass(f.type) and isinstance(value, dict):
+                    kwargs[f.name] = cls._dict_to_dataclass(f.type, value)
+                else:
+                    kwargs[f.name] = value
+        return config_class(**kwargs)
+
+    @classmethod
+    def _merge_with_defaults(cls, source: "JobConfig", defaults: "JobConfig") -> "JobConfig":
+        """Recursively merge two dataclass instances (source overrides defaults)."""
         if not is_dataclass(source) or not is_dataclass(defaults):
             return source or defaults
 
-        result = {}
+        merged_kwargs = {}
         for f in fields(source):
-            value_a = getattr(source, f.name)
-            value_b = getattr(defaults, f.name)
-            result[f.name] = self._merge_with_defaults(value_a, value_b)
-        return source.__class__(**result)
+            source_val = getattr(source, f.name)
+            default_val = getattr(defaults, f.name)
+            # If both are dataclasses, merge recursively
+            if is_dataclass(source_val) and is_dataclass(default_val):
+                merged_kwargs[f.name] = cls._merge_with_defaults(source_val, default_val)
+            else:
+                merged_kwargs[f.name] = source_val if source_val is not None else default_val
 
-    def _validate_config(self, config) -> None:
-        # TODO: Add more mandatory validations
-        assert config.model.name, "Model name is required"
-        assert config.model.flavor, "Model flavor is required"
-        assert config.model.tokenizer_path, "Model tokenizer path is required"
-
-
-if __name__ == "__main__":
-    # Example
-    config = JobConfig()
-    config.parse_args()
-    print(config)
+        return type(source)(**merged_kwargs)
+
+    def _validate_config(self) -> None:
+        # TODO: Add more mandatory validations
+        assert self.model.name, "Model name is required"
+        assert self.model.flavor, "Model flavor is required"
+        assert self.model.tokenizer_path, "Model tokenizer path is required"
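
Net effect: JobConfig is now a plain dataclass rather than a wrapper that forwarded attribute access through __getattr__, so editors and language servers can statically resolve fields like config.model.name, which is presumably the "Fix for LSP" of the commit title. Usage after this change might look like the sketch below; the flag spellings assume tyro's default naming and the TOML path is made up, so treat both as assumptions:

    # Hypothetical invocation:
    #   python train.py --job.config-file ./train_configs/my_run.toml --model.name llama3
    #
    # Precedence after parse_args(): cmdline > TOML > dataclass defaults.
    config = JobConfig.parse_args()
    print(config.model.name)  # plain attribute access that an LSP can check
    print(config.to_dict())   # dict view via dataclasses.asdict, e.g. for logging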
