diff --git a/litgpt/utils.py b/litgpt/utils.py
index 4d620e4163..cd0b3efe5b 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -558,13 +558,37 @@ def instantiate_bnb_optimizer(optimizer, model_parameters):
 
 
 def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
+    # Special care is taken because some optimizers do not accept some of the parameters referenced elsewhere in the code, for example "fused" in the pretrain.py script:
+    #   bnb.optim.AdamW8bit
+    #   grokadamw.GrokAdamW
+    #   torch.optim.RMSprop
+
     if isinstance(optimizer, str):
-        optimizer_cls = getattr(torch.optim, optimizer)
+        if "." in optimizer:
+            class_module, class_name = optimizer.rsplit(".", 1)
+        else:
+            class_module, class_name = "torch.optim", optimizer
+
+        module = __import__(class_module, fromlist=[class_name])
+        optimizer_cls = getattr(module, class_name)
+
+        valid_params = set(inspect.signature(optimizer_cls).parameters)
+        kwargs = {key: value for key, value in kwargs.items() if key in valid_params}
         optimizer = optimizer_cls(model_parameters, **kwargs)
-    else:
-        optimizer = dict(optimizer)  # copy
+    elif isinstance(optimizer, dict):
+        optimizer = dict(optimizer)  # copy
+        class_module, class_name = optimizer["class_path"].rsplit(".", 1)
+        module = __import__(class_module, fromlist=[class_name])
+        optimizer_cls = getattr(module, class_name)
+
+        valid_params = set(inspect.signature(optimizer_cls).parameters)
+        kwargs = {key: value for key, value in kwargs.items() if key in valid_params}
+
         optimizer["init_args"].update(kwargs)
         optimizer = instantiate_class(model_parameters, optimizer)
+    else:
+        raise ValueError(f'Unrecognized "optimizer" value: {optimizer}')
+
     return optimizer
 
 
diff --git a/tests/test_pretrain.py b/tests/test_pretrain.py
index 5669a84008..3b28894793 100644
--- a/tests/test_pretrain.py
+++ b/tests/test_pretrain.py
@@ -18,6 +18,27 @@
 from tests.conftest import RunIf
 
 
+@RunIf(min_cuda_gpus=1, standalone=True)
+@mock.patch("litgpt.pretrain.save_hyperparameters")
+def test_optimizer_args(_, tmp_path):
+    model_config = Config(block_size=2, n_layer=2, n_embd=4, n_head=2, padded_vocab_size=8)
+
+    dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
+    dataloader = DataLoader(dataset)
+    pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))
+
+    for optimizer_name in ("AdamW", "SGD", "RMSprop"):
+        pretrain.setup(
+            "pythia-14m",
+            devices=1,
+            optimizer=optimizer_name,
+            model_config=model_config,
+            out_dir=tmp_path,
+            train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
+            eval=EvalArgs(interval=1, max_iters=1, final_validation=False),
+        )
+
+
 @RunIf(min_cuda_gpus=2, standalone=True)
 # Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
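A minimal usage sketch (not part of the patch) of the three input forms the revised `instantiate_torch_optimizer` accepts. The throwaway `torch.nn.Linear` model and the hyperparameter values are illustrative assumptions; the dotted-path form is shown with `torch.optim` classes only so the snippet runs without extra dependencies:

```python
# Sketch only: exercising litgpt.utils.instantiate_torch_optimizer with the three
# accepted "optimizer" forms. The Linear layer is a stand-in model; values are illustrative.
import torch
from litgpt.utils import instantiate_torch_optimizer

# 1) Bare class name, resolved against torch.optim. "fused" is silently dropped
#    because torch.optim.RMSprop's signature does not accept it.
opt = instantiate_torch_optimizer("RMSprop", torch.nn.Linear(4, 4).parameters(), lr=0.1, fused=True)
assert "fused" not in opt.defaults

# 2) Fully qualified dotted path; the module left of the last "." is imported first.
opt = instantiate_torch_optimizer("torch.optim.AdamW", torch.nn.Linear(4, 4).parameters(), lr=1e-4)

# 3) jsonargparse-style dict with "class_path"/"init_args"; the filtered kwargs are
#    merged into "init_args" before lightning's instantiate_class builds the optimizer.
opt = instantiate_torch_optimizer(
    {"class_path": "torch.optim.SGD", "init_args": {"momentum": 0.9}},
    torch.nn.Linear(4, 4).parameters(),
    lr=0.05,
)
print(type(opt).__name__)  # SGD
```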