Support for optimizers which don't have "fused" parameter such as grokadamw and 8bit bnb (#1744)

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
mtasic85 and rasbt authored Sep 26, 2024
1 parent 87d694d commit b4b8dfc
Showing 2 changed files with 48 additions and 3 deletions.
30 changes: 27 additions & 3 deletions litgpt/utils.py
@@ -558,13 +558,37 @@ def instantiate_bnb_optimizer(optimizer, model_parameters):
 
 
 def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
+    # Special care is taken because some optimizers do not accept parameters referenced elsewhere in the code, for example "fused" in the pretrain.py script:
+    #   bnb.optim.AdamW8bit
+    #   grokadamw.GrokAdamW
+    #   torch.optim.RMSprop
+
     if isinstance(optimizer, str):
-        optimizer_cls = getattr(torch.optim, optimizer)
+        if "." in optimizer:
+            class_module, class_name = optimizer.rsplit(".", 1)
+        else:
+            class_module, class_name = "torch.optim", optimizer
+
+        module = __import__(class_module, fromlist=[class_name])
+        optimizer_cls = getattr(module, class_name)
+
+        valid_params = set(inspect.signature(optimizer_cls).parameters)
+        kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}
         optimizer = optimizer_cls(model_parameters, **kwargs)
-    else:
-        optimizer = dict(optimizer)  # copy
+    elif isinstance(optimizer, dict):
+        optimizer = dict(optimizer)  # copy
+        class_module, class_name = optimizer["class_path"].rsplit(".", 1)
+        module = __import__(class_module, fromlist=[class_name])
+        optimizer_cls = getattr(module, class_name)
+
+        valid_params = set(inspect.signature(optimizer_cls).parameters)
+        kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}
+
         optimizer["init_args"].update(kwargs)
         optimizer = instantiate_class(model_parameters, optimizer)
+    else:
+        raise ValueError(f'Unrecognized "optimizer" value: {optimizer}')
+
     return optimizer


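For reference, a hedged usage sketch of the updated helper (not part of the commit): it assumes litgpt is installed, the optional grokadamw package is available, and that a model has already been built.

# Usage sketch: the three "optimizer" forms accepted after this change.
import torch
from litgpt.utils import instantiate_torch_optimizer

model = torch.nn.Linear(8, 8)  # stand-in for a real model

# 1. Bare class name, resolved against torch.optim.
opt = instantiate_torch_optimizer("AdamW", model.parameters(), lr=1e-3)

# 2. Fully qualified class path for a third-party optimizer (requires `pip install grokadamw`).
opt = instantiate_torch_optimizer("grokadamw.GrokAdamW", model.parameters(), lr=1e-3)

# 3. Dict with class_path / init_args, as produced by the CLI config; extra kwargs the class
#    accepts (here weight_decay) are merged into init_args before instantiation.
opt = instantiate_torch_optimizer(
    {"class_path": "torch.optim.AdamW", "init_args": {"lr": 1e-3}},
    model.parameters(),
    weight_decay=0.1,
)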
21 changes: 21 additions & 0 deletions tests/test_pretrain.py
@@ -18,6 +18,27 @@
 from tests.conftest import RunIf
 
 
+@RunIf(min_cuda_gpus=1, standalone=True)
+@mock.patch("litgpt.pretrain.save_hyperparameters")
+def test_optimizer_args(_, tmp_path):
+    model_config = Config(block_size=2, n_layer=2, n_embd=4, n_head=2, padded_vocab_size=8)
+
+    dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
+    dataloader = DataLoader(dataset)
+    pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))
+
+    for optimizer_name in ("AdamW", "SGD", "RMSprop"):
+        pretrain.setup(
+            "pythia-14m",
+            devices=1,
+            optimizer=optimizer_name,
+            model_config=model_config,
+            out_dir=tmp_path,
+            train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
+            eval=EvalArgs(interval=1, max_iters=1, final_validation=False),
+        )
+
+
 @RunIf(min_cuda_gpus=2, standalone=True)
 # Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
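The committed test requires at least one CUDA GPU; for intuition, a minimal CPU-only sketch (not part of the commit, assuming litgpt is installed) of the kwargs filtering it relies on:

# CPU-only sketch: an unsupported kwarg such as "fused" is dropped instead of raising a TypeError.
import torch
from litgpt.utils import instantiate_torch_optimizer

model = torch.nn.Linear(4, 4)
opt = instantiate_torch_optimizer("RMSprop", model.parameters(), lr=1e-3, fused=True)
assert isinstance(opt, torch.optim.RMSprop)
assert "fused" not in opt.defaults  # dropped before the optimizer was constructed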
