Merge pull request #324 from kozistr/feature/ranger25
[Refactor] flexible and consistent `optimizer` parameters for `Lookahead`, `TRAC`, and `OrthoGrad` optimizers
kozistr authored Jan 18, 2025
2 parents 5baa713 + 87e1a60 commit a9fb8a2
Showing 28 changed files with 496 additions and 130 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -10,8 +10,8 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **89 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
* Somewhat more optimized than the original implementation
@@ -198,6 +198,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| Grams | *Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
| OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
| Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |

## Supported LR Scheduler

12 changes: 12 additions & 0 deletions docs/changelogs/v3.3.4.md
@@ -0,0 +1,12 @@
### Change Log

### Feature

* Support the `OrthoGrad` feature in `create_optimizer()`. (#324)
* Enhanced flexibility for the `optimizer` parameter in `Lookahead`, `TRAC`, and `OrthoGrad` optimizers. (#324)
* Now supports both `torch.optim.Optimizer` instances and optimizer classes
* You can now use the `Lookahead` optimizer in two ways (see the sketch after this list).
* `Lookahead(AdamW(model.parameters(), lr=1e-3), k=5, alpha=0.5)`
* `Lookahead(AdamW, k=5, alpha=0.5, params=model.parameters())`
* Implement `SPAM` optimizer. (#324)
* [Spike-Aware Adam with Momentum Reset for Stable LLM Training](https://arxiv.org/abs/2501.06842)
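
As a minimal sketch of the two calling conventions above (the tiny `nn.Linear` model is a placeholder, and forwarding `lr` through `OrthoGrad` in class mode is an assumption based on the wording of this entry, not something shown in the diff):

```python
from torch import nn
from torch.optim import AdamW

from pytorch_optimizer import Lookahead, OrthoGrad

model = nn.Linear(10, 2)

# 1) wrap an already-constructed torch.optim.Optimizer instance
opt_a = Lookahead(AdamW(model.parameters(), lr=1e-3), k=5, alpha=0.5)

# 2) pass the optimizer class and let the wrapper build it from `params`
opt_b = Lookahead(AdamW, k=5, alpha=0.5, params=model.parameters())

# OrthoGrad (and TRAC) are described as accepting the same two forms;
# forwarding `lr` here is assumed, not taken from the diff
opt_c = OrthoGrad(AdamW, params=model.parameters(), lr=1e-3)
```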
5 changes: 3 additions & 2 deletions docs/index.md
@@ -10,8 +10,8 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **89 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
* Somewhat more optimized than the original implementation
@@ -198,6 +198,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| Grams | *Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
| OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
| Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |

## Supported LR Scheduler

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -368,6 +368,10 @@
:docstring:
:members:

::: pytorch_optimizer.SPAM
:docstring:
:members:

::: pytorch_optimizer.SRMM
:docstring:
:members:
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -18,9 +18,9 @@ keywords = [
"Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
"NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
"Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo",
"ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC",
"WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered",
"Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
"ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "Tiger",
"TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
"Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
1 change: 1 addition & 0 deletions pytorch_optimizer/__init__.py
@@ -61,6 +61,7 @@
SGDW,
SM3,
SOAP,
SPAM,
SRMM,
SWATS,
TRAC,
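
With `SPAM` now exported at the package root, a basic training-step sketch (only the conventional `params` and `lr` arguments are used; SPAM's other hyperparameters and defaults are not visible in this diff):

```python
import torch
from torch import nn
from torch.nn import functional as F

from pytorch_optimizer import SPAM

model = nn.Linear(10, 2)
optimizer = SPAM(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 10), torch.randn(8, 2)

optimizer.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()
```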
7 changes: 4 additions & 3 deletions pytorch_optimizer/base/scheduler.py
@@ -1,16 +1,17 @@
from abc import ABC, abstractmethod
from typing import List

from torch.optim import Optimizer

from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
from pytorch_optimizer.base.types import OPTIMIZER


class BaseLinearWarmupScheduler(ABC):
r"""BaseLinearWarmupScheduler class.
The LR Scheduler class based on this class has linear warmup strategy.
:param optimizer: Optimizer. OPTIMIZER. It will set learning rate to all trainable parameters in optimizer.
:param optimizer: Optimizer. It will set learning rate to all trainable parameters in optimizer.
:param t_max: int. total steps to train.
:param max_lr: float. maximum lr.
:param min_lr: float. minimum lr.
@@ -20,7 +21,7 @@ class BaseLinearWarmupScheduler(ABC):

def __init__(
self,
optimizer: OPTIMIZER,
optimizer: Optimizer,
t_max: int,
max_lr: float,
min_lr: float = 0.0,
1 change: 1 addition & 0 deletions pytorch_optimizer/base/types.py
@@ -11,6 +11,7 @@
PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
STATE = Dict
OPTIMIZER = Type[Optimizer]
OPTIMIZER_INSTANCE_OR_CLASS = Union[OPTIMIZER, Optimizer]
SCHEDULER = Type[LRScheduler]

HUTCHINSON_G = Literal['gaussian', 'rademacher']
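
To make the widened alias concrete, a short illustration of how the aliases relate; the `normalize` helper is hypothetical and only stands in for the wrappers' internal handling, it is not part of the library:

```python
from typing import Type, Union

from torch import nn
from torch.optim import AdamW, Optimizer

OPTIMIZER = Type[Optimizer]
OPTIMIZER_INSTANCE_OR_CLASS = Union[OPTIMIZER, Optimizer]


def normalize(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimizer:
    """Hypothetical helper: instantiate when a class is given, pass instances through."""
    return optimizer(**kwargs) if isinstance(optimizer, type) else optimizer


model = nn.Linear(10, 2)
opt_from_class = normalize(AdamW, params=model.parameters(), lr=1e-3)
opt_from_instance = normalize(AdamW(model.parameters(), lr=1e-3))
```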
7 changes: 3 additions & 4 deletions pytorch_optimizer/lr_scheduler/cosine_anealing.py
@@ -1,10 +1,9 @@
import math
from typing import List, Optional

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from pytorch_optimizer.base.types import OPTIMIZER


class CosineAnnealingWarmupRestarts(LRScheduler):
r"""CosineAnnealingWarmupRestarts.
@@ -21,7 +20,7 @@ class CosineAnnealingWarmupRestarts(LRScheduler):

def __init__(
self,
optimizer: OPTIMIZER,
optimizer: Optimizer,
first_cycle_steps: int,
cycle_mult: float = 1.0,
max_lr: float = 1e-4,
@@ -53,7 +52,7 @@ def __init__(

self.init_lr()

def init_lr(self):
def init_lr(self) -> None:
self.base_lrs = []
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.min_lr
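
A usage sketch for the constructor shown in this hunk, which now annotates `optimizer` as an already-built `torch.optim.Optimizer` (importing the scheduler from the package root and the step counts are assumptions; parameters not shown above keep their defaults):

```python
from torch import nn
from torch.optim import AdamW

from pytorch_optimizer import CosineAnnealingWarmupRestarts

model = nn.Linear(10, 2)
optimizer = AdamW(model.parameters(), lr=1e-3)

scheduler = CosineAnnealingWarmupRestarts(
    optimizer, first_cycle_steps=100, cycle_mult=1.0, max_lr=1e-3
)

for _ in range(300):
    # ... forward / backward / optimizer.step() goes here ...
    scheduler.step()
```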
7 changes: 3 additions & 4 deletions pytorch_optimizer/lr_scheduler/rex.py
@@ -1,9 +1,8 @@
from typing import List, Optional

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from pytorch_optimizer.base.types import OPTIMIZER


class REXScheduler(LRScheduler):
r"""Revisiting Budgeted Training with an Improved Schedule.
@@ -16,7 +15,7 @@ class REXScheduler(LRScheduler):

def __init__(
self,
optimizer: OPTIMIZER,
optimizer: Optimizer,
total_steps: int,
max_lr: float = 1.0,
min_lr: float = 0.0,
@@ -35,7 +34,7 @@ def __init__(

self.init_lr()

def init_lr(self):
def init_lr(self) -> None:
self.base_lrs = []
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.min_lr
23 changes: 15 additions & 8 deletions pytorch_optimizer/optimizer/__init__.py
@@ -1,10 +1,10 @@
import fnmatch
from importlib.util import find_spec
from typing import Dict, List, Optional, Sequence, Set, Union
from typing import Dict, List, Optional, Sequence, Set, Type, Union

import torch
from torch import nn
from torch.optim import AdamW
from torch.optim import AdamW, Optimizer

from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS
from pytorch_optimizer.optimizer.a2grad import A2Grad
@@ -83,6 +83,7 @@
from pytorch_optimizer.optimizer.sm3 import SM3
from pytorch_optimizer.optimizer.soap import SOAP
from pytorch_optimizer.optimizer.sophia import SophiaH
from pytorch_optimizer.optimizer.spam import SPAM
from pytorch_optimizer.optimizer.srmm import SRMM
from pytorch_optimizer.optimizer.swats import SWATS
from pytorch_optimizer.optimizer.tiger import Tiger
@@ -286,6 +287,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
MARS,
SGDSaI,
Grams,
SPAM,
Ranger25,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
@@ -298,31 +300,36 @@ def create_optimizer(
weight_decay: float = 0.0,
wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
use_lookahead: bool = False,
use_orthograd: bool = False,
**kwargs,
):
) -> Optimizer:
r"""Build optimizer.
:param model: nn.Module. model.
:param optimizer_name: str. name of optimizer.
:param lr: float. learning rate.
:param weight_decay: float. weight decay.
:param wd_ban_list: List[str]. weight decay ban list by layer.
:param use_lookahead: bool. use lookahead.
:param use_lookahead: bool. use Lookahead.
:param use_orthograd: bool. use OrthoGrad.
"""
optimizer_name = optimizer_name.lower()

parameters = (
get_optimizer_parameters(model, weight_decay, wd_ban_list) if weight_decay > 0.0 else model.parameters()
)

optimizer = load_optimizer(optimizer_name)
optimizer_class: OPTIMIZER = load_optimizer(optimizer_name)

if optimizer_name == 'alig':
optimizer = optimizer(parameters, max_lr=lr, **kwargs)
optimizer = optimizer_class(parameters, max_lr=lr, **kwargs)
elif optimizer_name in {'lomo', 'adalomo', 'adammini'}:
optimizer = optimizer(model, lr=lr, **kwargs)
optimizer = optimizer_class(model, lr=lr, **kwargs)
else:
optimizer = optimizer(parameters, lr=lr, **kwargs)
optimizer = optimizer_class(parameters, lr=lr, **kwargs)

if use_orthograd:
optimizer = OrthoGrad(optimizer, **kwargs)

if use_lookahead:
optimizer = Lookahead(
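
A hedged sketch of the new flag in `create_optimizer` (the model, hyperparameters, and the choice of `adamp` as a registered optimizer name are placeholders): the base optimizer is built first, wrapped in `OrthoGrad` when `use_orthograd=True`, and then in `Lookahead` when `use_lookahead=True`.

```python
from torch import nn

from pytorch_optimizer import create_optimizer

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

optimizer = create_optimizer(
    model,
    optimizer_name='adamp',   # assumed to be among the registered names
    lr=1e-3,
    weight_decay=1e-2,
    use_lookahead=True,
    use_orthograd=True,
)
```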
43 changes: 17 additions & 26 deletions pytorch_optimizer/optimizer/experimental/ranger25.py
@@ -11,6 +11,8 @@
class Ranger25(BaseOptimizer):
r"""Mixin' every fancy optimizer hacks.
ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
@@ -19,10 +21,10 @@ class Ranger25(BaseOptimizer):
:param fixed_decay: bool. fix weight decay.
:param alpha: float. usually between 4 and 10 would work well.
:param t_alpha_beta3: Optional[float]. total number of iterations is preferred when needed.
:param n_sma_threshold: number of SMA threshold (recommended is 5).
:param cautious: bool. whether to use the Cautious variant.
:param stable_adamw: bool. whether to use stable AdamW variant.
:param eps: float. term added to the denominator to improve numerical stability.
:param eps: Optional[float]. term added to the denominator to improve numerical stability. when eps is None and
stable_adamw is False, adam-atan2 feature will be used.
"""

def __init__(
@@ -35,10 +37,9 @@ def __init__(
fixed_decay: bool = False,
alpha: float = 5.0,
t_alpha_beta3: Optional[float] = None,
n_sma_threshold: int = 5,
cautious: bool = True,
stable_adamw: bool = True,
eps: float = 1e-8,
eps: Optional[float] = 1e-8,
**kwargs,
):
self.validate_learning_rate(lr)
@@ -48,9 +49,8 @@
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

self.n_sma_threshold = n_sma_threshold
self.cautious = cautious
self.stable_adamw = stable_adamw
self.stable_adamw: bool = stable_adamw if isinstance(eps, float) else False

defaults: DEFAULTS = {
'lr': lr,
Expand All @@ -60,7 +60,7 @@ def __init__(
'fixed_decay': fixed_decay,
'alpha': alpha,
't_alpha_beta3': t_alpha_beta3,
'eps': eps,
'eps': eps if (eps is not None) or (eps is None and not stable_adamw) else 1e-8,
}

super().__init__(params, defaults)
@@ -147,38 +147,29 @@ def step(self, closure: CLOSURE = None) -> LOSS:

exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']

de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])

normed_grad = grad.div(de_nom).clamp_(-clip, clip)
normed_grad = grad.div(
exp_avg_sq.sqrt().clamp_(min=group['eps'] if group['eps'] is not None else 1e-8)
).clamp_(-clip, clip)

exp_avg.mul_(beta1).add_(normed_grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
exp_avg_slow.mul_(beta3_t).add_(normed_grad, alpha=1.0 - beta3_t)

de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

update = exp_avg.clone()
if self.cautious:
self.apply_cautious(update, grad)

if self.stable_adamw:
step_size /= self.get_stable_adamw_rms(grad, exp_avg_sq)

step_size, n_sma = self.get_rectify_step_size(
is_rectify=True,
step=group['step'],
lr=step_size,
beta2=beta2,
n_sma_threshold=self.n_sma_threshold,
degenerated_to_sgd=False,
)
update.add_(exp_avg_slow, alpha=alpha_t)

de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq)

update.add_(exp_avg_slow, alpha=alpha_t).div_(de_nom)
if group['eps'] is not None:
p.addcdiv_(update, de_nom.add_(group['eps']), value=-step_size)
continue

if n_sma >= self.n_sma_threshold:
de_nom = exp_avg_sq.sqrt().add_(group['eps'])
p.addcdiv_(update, de_nom, value=-step_size)
else:
p.add_(update, alpha=-step_size)
p.add_(update.atan2_(de_nom), alpha=-step_size)

return loss
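
As a standalone illustration of the two final-update branches above (the helper and tensors are made up, not part of the optimizer): a finite `eps` keeps the usual eps-stabilized `addcdiv_`, while `eps=None` switches to the Adam-atan2 form, which bounds the update without an epsilon term.

```python
import torch


def apply_update(p, update, de_nom, step_size, eps):
    """Hypothetical helper mirroring the branch structure of Ranger25.step()."""
    if eps is not None:
        # classic Adam-style update: divide by the eps-stabilized denominator
        p.addcdiv_(update, de_nom.add(eps), value=-step_size)
    else:
        # Adam-atan2 update: atan2 bounds the ratio, so no eps is needed
        p.add_(update.atan2(de_nom), alpha=-step_size)


p = torch.zeros(4)
update = torch.randn(4)
de_nom = torch.rand(4) + 0.1

apply_update(p, update, de_nom, step_size=1e-3, eps=1e-8)  # eps branch
apply_update(p, update, de_nom, step_size=1e-3, eps=None)  # atan2 branch
```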