add RAdamScheduleFree optimizer (#35313)

* add RAdamScheduleFree optimizer

* revert schedulefree version to the minimum requirement

* refine is_schedulefree_available so that it can take min_version

* refine documents

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
nhamanasu and ArthurZucker authored Feb 12, 2025
1 parent f5fff67 commit 377d8e2
Showing 5 changed files with 67 additions and 20 deletions.
15 changes: 10 additions & 5 deletions docs/source/en/trainer.md
@@ -542,13 +542,17 @@ trainer = Trainer(
trainer.train()
```

This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
This script demonstrates how to fine-tune the [google/gemma-2b](https://huggingface.co/google/gemma-2b) model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.

### Schedule Free Optimizer
### Schedule-Free Optimizer

The Schedule-Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
Supported optimizers for Schedule-Free are `schedule_free_radam`, `schedule_free_adamw` and `schedule_free_sgd`. First install schedulefree from pypi `pip install schedulefree`.

The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
Supported optimizers for SFO are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`.
Additionally, neither `warmup_steps` nor `warmup_ratio` parameters are required when using `schedule_free_radam`.

By default, we recommend setting `lr_scheduler_type="constant"` in the `TrainingArguments`. Setting other `lr_scheduler_type` would also work, but combining Schedule-Free with other learning rate schedules is not well-studied both in research and in practice, as it may affect the optimizer's intended behavior and performance guarantees.

Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision:

@@ -564,7 +568,8 @@ args = TrainingArguments(
output_dir="./test-schedulefree",
max_steps=1000,
per_device_train_batch_size=4,
optim="schedule_free_adamw",
optim="schedule_free_radam",
lr_scheduler_type="constant",
gradient_checkpointing=True,
logging_strategy="steps",
logging_steps=1,
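The docs above note that `schedule_free_radam` needs no warmup and recommend `lr_scheduler_type="constant"`. Outside of `Trainer`, schedulefree optimizers are driven through their train/eval mode switch; the following is a minimal sketch of that standalone usage, assuming `schedulefree>=1.4.0` is installed and using a toy model and data as placeholders:

```python
# Minimal sketch: using RAdamScheduleFree directly, without Trainer.
# Assumes `pip install schedulefree>=1.4.0`; the model and data are toy placeholders.
import torch
from schedulefree import RAdamScheduleFree

model = torch.nn.Linear(10, 2)
# No warmup_steps argument is needed for the RAdam variant.
optimizer = RAdamScheduleFree(model.parameters(), lr=1e-3)

model.train()
optimizer.train()  # schedule-free optimizers must be put in train mode before stepping
for _ in range(100):
    inputs = torch.randn(4, 10)
    labels = torch.randint(0, 2, (4,))
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

model.eval()
optimizer.eval()  # switch to the averaged iterate before evaluation or checkpointing
```

When the optimizer is selected through `Trainer`, this mode switching is expected to be handled internally, which is why the docs example above only sets `optim` and `lr_scheduler_type`.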
24 changes: 20 additions & 4 deletions src/transformers/trainer.py
@@ -1644,28 +1644,44 @@ def optimizer_hook(param):
raise ValueError("Invalid optimizer")
optimizer_kwargs.update(adam_kwargs)
elif args.optim in [
OptimizerNames.SCHEDULE_FREE_RADAM,
OptimizerNames.SCHEDULE_FREE_ADAMW,
OptimizerNames.SCHEDULE_FREE_SGD,
]:
if not is_schedulefree_available():
raise ImportError(
"You need to install `schedulefree` in order to use schedulefree optimizers"
" install it with `pip install schedulefree`"
"You need to install `schedulefree` in order to use schedulefree optimizers. "
"Install it with `pip install schedulefree.`"
)
if not is_accelerate_available("0.30.0"):
raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
from schedulefree import AdamWScheduleFree, SGDScheduleFree

additional_optim_kwargs = {}
if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
require_warmup = True

if args.optim == OptimizerNames.SCHEDULE_FREE_RADAM:
if not is_schedulefree_available("1.4.0"):
raise ImportError(
"You need to install `schedulefree>=1.4.0` in order to use RAdamScheduleFree optimizer. "
"Install it with `pip install schedulefree.`"
)
from schedulefree import RAdamScheduleFree

optimizer_cls = RAdamScheduleFree
additional_optim_kwargs = adam_kwargs
require_warmup = False
elif args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
optimizer_cls = AdamWScheduleFree
additional_optim_kwargs = adam_kwargs
elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
optimizer_cls = SGDScheduleFree
else:
raise ValueError("Invalid schedulefree optimizer")

additional_optim_kwargs["weight_decay"] = args.weight_decay
additional_optim_kwargs["warmup_steps"] = args.warmup_steps
if require_warmup:
additional_optim_kwargs["warmup_steps"] = args.warmup_steps
additional_optim_kwargs.update(
{
"weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
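The branch above only attaches `warmup_steps` when `require_warmup` is true, i.e. for the AdamW and SGD variants, while extra knobs such as `weight_lr_power` flow in through `optim_args`. A hedged sketch of how a user would exercise this path from `TrainingArguments` (the output path is a placeholder; `optim_args` is parsed by `Trainer` as comma-separated `key=value` pairs):

```python
# Sketch only: selecting the new optimizer from TrainingArguments and forwarding the
# extra schedule-free knob visible in the diff (weight_lr_power) through `optim_args`.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./schedulefree-radam-demo",  # placeholder path
    optim="schedule_free_radam",             # requires schedulefree>=1.4.0
    optim_args="weight_lr_power=2.0",        # read back via optim_args.get("weight_lr_power", 2.0)
    lr_scheduler_type="constant",            # recommended in the updated docs
    learning_rate=2e-3,
    max_steps=100,
)
# Pass `args` to Trainer together with a model and dataset, as in the docs example above.
```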
1 change: 1 addition & 0 deletions src/transformers/training_args.py
@@ -182,6 +182,7 @@ class OptimizerNames(ExplicitEnum):
LOMO = "lomo"
ADALOMO = "adalomo"
GROKADAMW = "grokadamw"
SCHEDULE_FREE_RADAM = "schedule_free_radam"
SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
SCHEDULE_FREE_SGD = "schedule_free_sgd"

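The new enum member is what lets the string `"schedule_free_radam"` pass validation for `TrainingArguments.optim`. A small sketch of that mapping, assuming `ExplicitEnum` is the usual string-backed enum used in `training_args.py`:

```python
# Sketch: the `optim` string round-trips through the enum member added above.
from transformers.training_args import OptimizerNames

assert OptimizerNames("schedule_free_radam") is OptimizerNames.SCHEDULE_FREE_RADAM
assert OptimizerNames.SCHEDULE_FREE_RADAM.value == "schedule_free_radam"
```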
7 changes: 4 additions & 3 deletions src/transformers/utils/import_utils.py
@@ -89,6 +89,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

ACCELERATE_MIN_VERSION = "0.26.0"
SCHEDULEFREE_MIN_VERSION = "1.2.6"
FSDP_MIN_VERSION = "1.12.0"
GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
@@ -108,7 +109,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw")
_schedulefree_available = _is_package_available("schedulefree")
_schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True)
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -410,8 +411,8 @@ def is_grokadamw_available():
return _grokadamw_available


def is_schedulefree_available():
return _schedulefree_available
def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION):
return _schedulefree_available and version.parse(_schedulefree_version) >= version.parse(min_version)


def is_pyctcdecode_available():
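The version-gated check above is what lets the trainer require `schedulefree>=1.4.0` only for the RAdam variant while keeping `1.2.6` as the floor for the older optimizers. A simplified sketch of the underlying pattern (not the exact transformers implementation, which also handles packages whose import name differs from their distribution name):

```python
# Simplified sketch of the version-gated availability check; same importlib/packaging
# pattern as the diff above, reduced to the essentials.
import importlib.metadata
import importlib.util

from packaging import version

SCHEDULEFREE_MIN_VERSION = "1.2.6"


def _package_available_with_version(pkg_name: str):
    if importlib.util.find_spec(pkg_name) is None:
        return False, "N/A"
    try:
        return True, importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        return False, "N/A"


_schedulefree_available, _schedulefree_version = _package_available_with_version("schedulefree")


def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION) -> bool:
    # The RAdamScheduleFree branch calls this with min_version="1.4.0".
    return _schedulefree_available and version.parse(_schedulefree_version) >= version.parse(min_version)
```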
40 changes: 32 additions & 8 deletions tests/trainer/test_trainer.py
@@ -1865,14 +1865,38 @@ def test_schedulefree_adam(self):
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

# Trainer without inf/nan filter
args = TrainingArguments(
self.get_auto_remove_tmp_dir(),
learning_rate=1e-9,
logging_steps=5,
optim="schedule_free_adamw",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="schedule_free_adamw",
lr_scheduler_type="constant",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()

@require_schedulefree
@require_torch_gpu
def test_schedulefree_radam(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
lr_scheduler_type="constant",
optim="schedule_free_radam",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

# Check this works
_ = trainer.train()
