[Feature] Implement ScheduleFreeRAdam, LaProp optimizers and lots of things #304

Merged (14 commits, Dec 4, 2024)
11 changes: 8 additions & 3 deletions README.md
@@ -8,9 +8,13 @@
| Status | [![PyPi download](https://static.pepy.tech/badge/pytorch-optimizer)](https://pepy.tech/project/pytorch-optimizer) [![PyPi month download](https://static.pepy.tech/badge/pytorch-optimizer/month)](https://pepy.tech/project/pytorch-optimizer) |
| License | [![apache](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) |

**pytorch-optimizer** is a collection of optimizers & lr schedulers for PyTorch.
I re-implemented the algorithms based on the original papers, with speed & memory tweaks and plug-ins. It also includes useful and practical optimization ideas.
Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
## Reasons to use `pytorch-optimizer`

1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
2. Includes many variants such as `Cautious`, `AdamD`, and `Gradient Centralization`
3. Easy-to-use, clean, and tested code
4. Active maintenance
5. Somewhat more optimized than the original implementations

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -187,6 +191,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |

## Supported LR Scheduler

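The README table above registers the new `LaProp` entry; the classes themselves are exported from the package (see the `__init__` diffs further down). Below is a minimal usage sketch, not taken from the PR: it assumes the standard `params`/`lr` constructor signature shared by the library's other optimizers, and the model, data, and hyper-parameters are placeholders.

```python
import torch

from pytorch_optimizer import LaProp, ScheduleFreeRAdam

model = torch.nn.Linear(10, 1)

# LaProp: applies the adaptive (RMS) normalization to the gradient before the
# momentum accumulation, instead of after it as Adam does.
optimizer = LaProp(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# ScheduleFreeRAdam: assumption, based on the existing ScheduleFree* classes,
# that it follows the same train()/eval() protocol for the schedule-free
# parameter averaging.
optimizer = ScheduleFreeRAdam(model.parameters(), lr=1e-3)
optimizer.train()  # switch to training mode before optimization steps
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
optimizer.eval()   # switch to the averaged weights before validation/checkpointing
```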
4 changes: 4 additions & 0 deletions docs/changelogs/v3.3.1.md
@@ -6,3 +6,7 @@
* [Decoupled Momentum Optimization](https://arxiv.org/abs/2411.19870)
* Implement `Muon` optimizer. (#302)
* [MomentUm Orthogonalized by Newton-schulz](https://github.com/KellerJordan/Muon)
* Implement `ScheduleFreeRAdam` optimizer. (#304)
* Implement `LaProp` optimizer. (#304)
* [Separating Momentum and Adaptivity in Adam](https://arxiv.org/abs/2002.04839)
* Support the `Cautious` variant for the `LaProp`, `AdamP`, and `ADOPT` optimizers. (#304)
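The `Cautious` support listed above is exposed through a constructor flag, visible in the `adamp.py` and `adopt.py` hunks further down as `cautious: bool = False`. A hedged sketch of turning it on; the model and learning rate are placeholders, and only the flag name is taken from this diff.

```python
import torch

from pytorch_optimizer import ADOPT, AdamP, LaProp

model = torch.nn.Linear(10, 1)

# The `cautious` keyword comes from this PR's diff; the default (False) keeps
# the original, unmasked update behavior.
optimizer = AdamP(model.parameters(), lr=1e-3, cautious=True)

# The changelog entry names LaProp and ADOPT as well, so the same keyword is
# expected to work there too (assumption for LaProp, whose hunk is not shown):
# optimizer = ADOPT(model.parameters(), lr=1e-3, cautious=True)
# optimizer = LaProp(model.parameters(), lr=1e-3, cautious=True)
```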
11 changes: 8 additions & 3 deletions docs/index.md
@@ -8,9 +8,13 @@
| Status | [![PyPi download](https://static.pepy.tech/badge/pytorch-optimizer)](https://pepy.tech/project/pytorch-optimizer) [![PyPi month download](https://static.pepy.tech/badge/pytorch-optimizer/month)](https://pepy.tech/project/pytorch-optimizer) |
| License | [![apache](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) |

**pytorch-optimizer** is a collection of optimizers & lr schedulers for PyTorch.
I re-implemented the algorithms based on the original papers, with speed & memory tweaks and plug-ins. It also includes useful and practical optimization ideas.
Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
## Reasons to use `pytorch-optimizer`

1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
2. Includes many variants such as `Cautious`, `AdamD`, and `Gradient Centralization`
3. Easy-to-use, clean, and tested code
4. Active maintenance
5. Somewhat more optimized than the original implementations

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -187,6 +191,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |

## Supported LR Scheduler

8 changes: 8 additions & 0 deletions docs/optimizer.md
@@ -204,6 +204,10 @@
:docstring:
:members:

::: pytorch_optimizer.LaProp
:docstring:
:members:

::: pytorch_optimizer.LARS
:docstring:
:members:
@@ -296,6 +300,10 @@
:docstring:
:members:

::: pytorch_optimizer.ScheduleFreeRAdam
:docstring:
:members:

::: pytorch_optimizer.StableAdamW
:docstring:
:members:
14 changes: 7 additions & 7 deletions pyproject.toml
@@ -14,13 +14,13 @@ keywords = [
"AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
"AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
"Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion",
"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS",
"Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
"Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
"ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM",
"StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"QGaLore",
"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp",
"LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID",
"PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
"ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP",
"SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
"bitsandbytes", "WSD", "QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -107,6 +107,7 @@
GrokFastAdamW,
Kate,
Lamb,
LaProp,
Lion,
Lookahead,
Muon,
@@ -123,6 +124,7 @@
SafeFP16Optimizer,
ScalableShampoo,
ScheduleFreeAdamW,
ScheduleFreeRAdam,
ScheduleFreeSGD,
Shampoo,
SignSGD,
5 changes: 4 additions & 1 deletion pytorch_optimizer/optimizer/__init__.py
@@ -50,6 +50,7 @@
from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW
from pytorch_optimizer.optimizer.kate import Kate
from pytorch_optimizer.optimizer.lamb import Lamb
from pytorch_optimizer.optimizer.laprop import LaProp
from pytorch_optimizer.optimizer.lars import LARS
from pytorch_optimizer.optimizer.lion import Lion
from pytorch_optimizer.optimizer.lomo import LOMO, AdaLOMO
@@ -71,7 +72,7 @@
from pytorch_optimizer.optimizer.ranger21 import Ranger21
from pytorch_optimizer.optimizer.rotograd import RotoGrad
from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeSGD
from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
from pytorch_optimizer.optimizer.sgdp import SGDP
from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
@@ -275,6 +276,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
FTRL,
DeMo,
Muon,
ScheduleFreeRAdam,
LaProp,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

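Because the registry keys are the lower-cased class names (last line of the hunk above), appending the new classes to `OPTIMIZER_LIST` should also make them resolvable by string. A small sketch, assuming `load_optimizer` is re-exported at the package root as in the existing docs; the model and learning rate are placeholders.

```python
import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(10, 1)

# Keys are str(cls.__name__).lower(), so these names follow from the diff above.
laprop_cls = load_optimizer('laprop')                # -> LaProp
sf_radam_cls = load_optimizer('schedulefreeradam')   # -> ScheduleFreeRAdam

optimizer = laprop_cls(model.parameters(), lr=1e-3)
```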
2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/adalite.py
@@ -41,7 +41,7 @@ def __init__(
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps1, 'eps1')
self.validate_non_negative(eps2, 'eps1')
self.validate_non_negative(eps2, 'eps2')

defaults: DEFAULTS = {
'lr': lr,
6 changes: 6 additions & 0 deletions pytorch_optimizer/optimizer/adamp.py
@@ -22,6 +22,7 @@ class AdamP(BaseOptimizer):
:param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
on scale-variant parameters.
:param use_gc: bool. use gradient centralization.
:param cautious: bool. whether to use the Cautious variant.
:param nesterov: bool. enables Nesterov momentum.
:param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
:param adanorm: bool. whether to use the AdaNorm variant.
@@ -40,6 +41,7 @@ def __init__(
delta: float = 0.1,
wd_ratio: float = 0.1,
use_gc: bool = False,
cautious: bool = False,
nesterov: bool = False,
r: float = 0.95,
adanorm: bool = False,
@@ -54,6 +56,7 @@ def __init__(
self.validate_non_negative(eps, 'eps')

self.use_gc = use_gc
self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
@@ -170,6 +173,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
bias_correction1=bias_correction1,
)

if self.cautious:
self.apply_cautious(perturb, grad)

p.add_(perturb, alpha=-step_size)

return loss
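`apply_cautious` itself is defined on the base optimizer and is not part of this hunk. The sketch below shows what the Cautious variant conventionally does: zero the update components whose sign disagrees with the current gradient, then rescale the survivors so the expected step size is preserved. Treat this as an assumption about the helper's behavior, not the library's exact code.

```python
import torch


def apply_cautious_sketch(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Assumed in-place behavior of BaseOptimizer.apply_cautious (a sketch)."""
    # Keep only the components where the update and the gradient agree in sign.
    mask = (update * grad > 0).to(grad.dtype)
    # Rescale so the average magnitude of the surviving step is not shrunk;
    # the clamp avoids dividing by ~0 when almost no signs agree.
    mask.div_(mask.mean().clamp_(min=1e-3))
    update.mul_(mask)
```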
11 changes: 10 additions & 1 deletion pytorch_optimizer/optimizer/adopt.py
@@ -17,6 +17,7 @@ class ADOPT(BaseOptimizer):
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param fixed_decay: bool. fix weight decay.
:param cautious: bool. whether to use the Cautious variant.
:param eps: float. term added to the denominator to improve numerical stability.
"""

@@ -29,6 +30,7 @@ def __init__(
weight_decay: float = 0.0,
weight_decouple: bool = False,
fixed_decay: bool = False,
cautious: bool = False,
eps: float = 1e-6,
**kwargs,
):
@@ -38,6 +40,7 @@ def __init__(
self.validate_non_negative(eps, 'eps')

self.clip_lambda = clip_lambda
self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
@@ -118,6 +121,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:

exp_avg.lerp_(normed_grad, weight=1.0 - beta1)

p.add_(exp_avg, alpha=-group['lr'])
if self.cautious:
update = exp_avg.clone()
self.apply_cautious(update, normed_grad)
else:
update = exp_avg

p.add_(update, alpha=-group['lr'])

return loss
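Note that the cautious branch clones `exp_avg` before masking: `apply_cautious` works in place, so masking a copy leaves the momentum state buffer untouched for later steps. A short end-to-end sketch with the new flag (placeholder model, data, and hyper-parameters):

```python
import torch

from pytorch_optimizer import ADOPT

model = torch.nn.Linear(10, 1)
optimizer = ADOPT(model.parameters(), lr=1e-3, cautious=True)

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()  # cautious masking is applied to a copy of exp_avg
```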