Merge pull request #252 from kozistr/feature/stableadamw-optimizer
[Feature] Implement StableAdamW optimizer
kozistr authored Jul 6, 2024
2 parents b316ef9 + 39a38f3 commit 5db0994
Showing 11 changed files with 174 additions and 5 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

**pytorch-optimizer** is a collection of optimizers and LR schedulers for PyTorch.
I re-implemented the algorithms from the original papers, adding speed and memory tweaks and plug-ins. It also includes useful and practical optimization ideas.
Currently, **70 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -167,6 +167,7 @@ supported_optimizers = get_supported_optimizers()
| FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | <https://arxiv.org/abs/2405.12807> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) |
| Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
| Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |

## Supported LR Scheduler

2 changes: 2 additions & 0 deletions docs/changelogs/v3.0.2.md
@@ -7,6 +7,8 @@
* Add more Pytorch built-in lr schedulers. (#248)
* Implement `Kate` optimizer. (#249, #251)
* [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648)
* Implement `StableAdamW` optimizer. (#250, #252)
* [Stable and low-precision training for large-scale vision-language models](https://arxiv.org/abs/2304.13013)

### Refactor

3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

**pytorch-optimizer** is a collection of optimizers and LR schedulers for PyTorch.
I re-implemented the algorithms from the original papers, adding speed and memory tweaks and plug-ins. It also includes useful and practical optimization ideas.
Currently, **70 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -167,6 +167,7 @@ supported_optimizers = get_supported_optimizers()
| FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | <https://arxiv.org/abs/2405.12807> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) |
| Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
| Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |

## Supported LR Scheduler

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -268,6 +268,10 @@
    :docstring:
    :members:

::: pytorch_optimizer.StableAdamW
    :docstring:
    :members:

::: pytorch_optimizer.AccSGD
    :docstring:
    :members:
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -17,8 +17,8 @@ keywords = [
"GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
"PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
"ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH",
"SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM",
"Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -44,6 +44,7 @@
from pytorch_optimizer.optimizer.adamod import AdaMod
from pytorch_optimizer.optimizer.adamp import AdamP
from pytorch_optimizer.optimizer.adams import AdamS
from pytorch_optimizer.optimizer.adamw import StableAdamW
from pytorch_optimizer.optimizer.adan import Adan
from pytorch_optimizer.optimizer.adanorm import AdaNorm
from pytorch_optimizer.optimizer.adapnm import AdaPNM
@@ -201,6 +202,7 @@
FAdam,
GrokFastAdamW,
Kate,
StableAdamW,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

135 changes: 135 additions & 0 deletions pytorch_optimizer/optimizer/adamw.py
@@ -0,0 +1,135 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.utils import debias_beta


class StableAdamW(Optimizer, BaseOptimizer):
    r"""Stable and low-precision training for large-scale vision-language models.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param kahan_sum: bool. Enables Kahan summation for more accurate parameter updates when training in low precision
        (float16 or bfloat16).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use decoupled weight decay as in AdamW.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.99),
        kahan_sum: bool = True,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'kahan_sum': kahan_sum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'StableAdamW'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                state['kahan_comp'] = (
                    torch.zeros_like(p) if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16} else None
                )

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            # fold the Adam bias correction into the betas themselves (see debias_beta in utils.py)
            beta1_comp: float = 1.0 - debias_beta(beta1, group['step'])
            beta2_hat: float = debias_beta(beta2, group['step'])

            eps_p2: float = math.pow(group['eps'], 2)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                    state['kahan_comp'] = (
                        torch.zeros_like(p)
                        if (group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16})
                        else None
                    )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.lerp_(grad, weight=beta1_comp)
                exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1.0 - beta2_hat)

                # RMS of the gradient measured against the second-moment estimate (floored at eps^2);
                # values above 1 shrink the effective learning rate
                rms = grad.pow(2).div_(exp_avg_sq.clip(min=eps_p2)).mean().sqrt_()

                lr = group['lr'] / rms.clip(min=1.0)

                self.apply_weight_decay(
                    p,
                    p.grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16}:
                    # Kahan summation: accumulate the update through a compensation buffer to
                    # recover precision lost by float16/bfloat16 parameter updates
                    kahan_comp = state['kahan_comp']
                    kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

                    grad.copy_(p.detach())
                    p.add_(kahan_comp)

                    kahan_comp.add_(grad.sub_(p))
                else:
                    p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

        return loss
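
For context when reading the diff above: each step scales the learning rate by `1 / max(1, rms)`, where `rms` is the root mean of `grad**2 / max(exp_avg_sq, eps**2)`, so unusually noisy gradients shrink the update instead of requiring global gradient clipping, and Kahan summation optionally compensates rounding error for float16/bfloat16 parameters. A minimal usage sketch, assuming the package is installed and exports `StableAdamW` as added in `pytorch_optimizer/__init__.py` above; the model and data are illustrative only:

```python
import torch

from pytorch_optimizer import StableAdamW

# toy model and objective, for illustration only
model = torch.nn.Linear(4, 1)
optimizer = StableAdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
```

Per the referenced paper (<https://arxiv.org/abs/2304.13013>), this per-tensor update clipping is intended to remove the need for explicit gradient clipping during low-precision, large-scale training.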
11 changes: 11 additions & 0 deletions pytorch_optimizer/optimizer/utils.py
@@ -12,6 +12,17 @@
from pytorch_optimizer.base.types import PARAMETERS


def debias_beta(beta: float, step: int) -> float:
    r"""Apply the Adam-style debias correction to beta.

    Simplified form of `beta_hat = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`.

    :param beta: float. beta.
    :param step: int. the current step.
    """
    return (beta ** step - beta) / (beta ** step - 1.0)  # fmt: skip


def is_valid_parameters(parameters: PARAMETERS) -> bool:
    r"""Check whether the parameters are valid."""
    return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict)
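
As a sanity check on the new helper, `debias_beta` is the algebraic simplification of the familiar Adam bias-correction factor quoted in its docstring. A quick self-contained verification (plain Python, no project imports required):

```python
def debias_beta(beta: float, step: int) -> float:
    return (beta ** step - beta) / (beta ** step - 1.0)

# the docstring's "simplified form": beta_hat = beta * (1 - beta**(step - 1)) / (1 - beta**step)
for beta in (0.9, 0.99, 0.999):
    for step in (1, 2, 10, 1000):
        expected = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)
        assert abs(debias_beta(beta, step) - expected) < 1e-12

# at step 1 the debiased beta is 0, so the first moment starts out as the raw gradient
assert debias_beta(0.9, 1) == 0.0
```

This is why `step()` uses `beta1_comp = 1.0 - debias_beta(beta1, step)` as the `lerp_` weight: the bias correction is folded into the EMA coefficients rather than applied as a separate division.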
3 changes: 3 additions & 0 deletions tests/constants.py
@@ -67,6 +67,7 @@
Shampoo,
SignSGD,
SophiaH,
StableAdamW,
Tiger,
Yogi,
)
@@ -132,6 +133,7 @@
'schedulefreeadamw',
'fadam',
'grokfastadamw',
'stableadamw',
]

VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -463,6 +465,7 @@
(FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
(Kate, {'lr': 5e-2}, 10),
(StableAdamW, {'lr': 1e0}, 5),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
2 changes: 1 addition & 1 deletion tests/test_load_modules.py
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


def test_get_supported_optimizers():
    assert len(get_supported_optimizers()) == 69
    assert len(get_supported_optimizers()) == 70


def test_get_supported_lr_schedulers():
10 changes: 10 additions & 0 deletions tests/test_optimizers.py
@@ -640,3 +640,13 @@ def test_grokfast_ema(environment):
    model.fc2.bias.grad = torch.randn(1)

    _ = gradfilter_ema(model, None)


def test_stableadamw_optimizer(environment):
    _, model, _ = environment

    model.fc1.weight.data = torch.randn(2, 2, dtype=torch.float16)

    optimizer = load_optimizer('StableAdamW')(model.parameters())
    optimizer.reset()
    optimizer.step()
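
The smoke test above exercises construction and `reset()` with a `float16` weight. For completeness, a minimal sketch of a low-precision step that actually goes through the Kahan-summation branch; assumptions: `pytorch_optimizer` is installed, and the standalone `bfloat16` parameter below is purely illustrative:

```python
import torch

from pytorch_optimizer import StableAdamW

# a single bfloat16 parameter; kahan_sum=True (the default) keeps a compensation
# buffer so the low-precision parameter update loses less accuracy
param = torch.nn.Parameter(torch.zeros(16, dtype=torch.bfloat16))
param.grad = torch.randn(16, dtype=torch.bfloat16)

optimizer = StableAdamW([param], lr=1e-3, kahan_sum=True)
optimizer.step()
```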
