Skip to content

Commit

Permalink
Fix capturing hparams for loggers that don't support serializing non-…
Browse files Browse the repository at this point in the history
…primitives (#1281)
  • Loading branch information
awaelchli authored Apr 12, 2024
1 parent ca07e5e commit 3516bea
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
3 changes: 2 additions & 1 deletion extensions/thunder/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from litgpt.utils import (
CLI,
CycleIterator,
capture_hparams,
choose_logger,
chunked_cross_entropy,
copy_config_files,
Expand Down Expand Up @@ -97,7 +98,7 @@ def setup(
executors: If using Thunder, the executors to enable.
strategy: If desired, the strategy to use.
"""
hparams = locals()
hparams = capture_hparams()
data = TinyLlama() if data is None else data
if model_config is not None and model_name is not None:
raise ValueError("Only one of `model_name` or `model_config` can be set.")
Expand Down
3 changes: 2 additions & 1 deletion litgpt/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from litgpt.utils import (
CLI,
CycleIterator,
capture_hparams,
choose_logger,
chunked_cross_entropy,
copy_config_files,
Expand Down Expand Up @@ -87,7 +88,7 @@ def setup(
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""
hparams = locals()
hparams = capture_hparams()
data = TinyLlama() if data is None else data
if model_config is not None and model_name is not None:
raise ValueError("Only one of `model_name` or `model_config` can be set.")
Expand Down
18 changes: 17 additions & 1 deletion litgpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""Utility functions for training and inference."""
import inspect
import math
import pickle
import shutil
import sys
from dataclasses import asdict
from dataclasses import asdict, is_dataclass
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union
Expand Down Expand Up @@ -404,6 +405,21 @@ def CLI(*args: Any, **kwargs: Any) -> Any:
return CLI(*args, **kwargs)


def capture_hparams() -> Dict[str, Any]:
"""Captures the local variables ('hyperparameters') from where this function gets called."""
caller_frame = inspect.currentframe().f_back
locals_of_caller = caller_frame.f_locals
hparams = {}
for name, value in locals_of_caller.items():
if value is None or isinstance(value, (int, float, str, bool, Path)):
hparams[name] = value
elif is_dataclass(value):
hparams[name] = asdict(value)
else:
hparams[name] = str(value)
return hparams


def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
"""Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
from jsonargparse import capture_parser
Expand Down
23 changes: 23 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import asdict

import os
from contextlib import redirect_stderr
Expand All @@ -18,9 +19,11 @@
from lightning_utilities.core.imports import RequirementCache

from litgpt import GPT
from litgpt.args import TrainArgs
from litgpt.utils import (
CLI,
CycleIterator,
capture_hparams,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
Expand Down Expand Up @@ -219,6 +222,26 @@ def test_copy_config_files(fake_checkpoint_dir, tmp_path):
assert expected.issubset(contents)


def test_capture_hparams():
integer = 1
string = "string"
boolean = True
none = None
path = Path("/path")
dataclass = TrainArgs()
other = torch.nn.Linear(1, 1)
hparams = capture_hparams()
assert hparams == {
"integer": integer,
"string": string,
"boolean": boolean,
"none": none,
"path": path,
"dataclass": asdict(dataclass),
"other": str(other),
}


def _test_function(out_dir: Path, foo: bool = False, bar: int = 1):
save_hyperparameters(_test_function, out_dir)

Expand Down

0 comments on commit 3516bea

Please sign in to comment.