From e8c74beaa9e9f997d874054f8f1d338d782905fc Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 15:41:07 +0900 Subject: [PATCH 1/7] docs: v3.0.2 changelog --- docs/changelogs/v3.0.2.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelogs/v3.0.2.md b/docs/changelogs/v3.0.2.md index 2b2fc7029..dd9764e40 100644 --- a/docs/changelogs/v3.0.2.md +++ b/docs/changelogs/v3.0.2.md @@ -7,6 +7,8 @@ * Add more Pytorch built-in lr schedulers. (#248) * Implement `Kate` optimizer. (#249, #251) * [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648) +* Implement `StableAdamW` optimizer. (#250, #252) + * [Stable and low-precision training for large-scale vision-language models](https://arxiv.org/abs/2304.13013) ### Refactor From 9fff0303c185d8da28f50e3eb6d8874e4f3489a7 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 15:41:19 +0900 Subject: [PATCH 2/7] docs: StableAdamW optimizer --- README.md | 3 ++- docs/index.md | 3 ++- docs/optimizer.md | 4 ++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1f89c5d13..e385dfbfa 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **70 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported! +Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -167,6 +167,7 @@ supported_optimizers = get_supported_optimizers() | FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) | | Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) | | Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) | +| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) | ## Supported LR Scheduler diff --git a/docs/index.md b/docs/index.md index 1f89c5d13..e385dfbfa 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **70 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported! +Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). 
@@ -167,6 +167,7 @@ supported_optimizers = get_supported_optimizers() | FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) | | Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) | | Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) | +| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) | ## Supported LR Scheduler diff --git a/docs/optimizer.md b/docs/optimizer.md index 3d74a4365..7af635eab 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -268,6 +268,10 @@ :docstring: :members: +::: pytorch_optimizer.StableAdamW + :docstring: + :members: + ::: pytorch_optimizer.AccSGD :docstring: :members: From b31d232c127a820fb7660aeb29fa75c09a05ceec Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 15:41:29 +0900 Subject: [PATCH 3/7] feature: debias_beta --- pytorch_optimizer/optimizer/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index 50cdb09ef..87a281de6 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -12,6 +12,14 @@ from pytorch_optimizer.base.types import PARAMETERS +def debias_beta(beta: float, step: int) -> float: + r"""Applies the Adam-style debias correction into beta. + + Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)` + """ + return (beta**step - beta) / (beta**step - 1.0) + + def is_valid_parameters(parameters: PARAMETERS) -> bool: r"""Check where the parameters are valid.""" return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict) From 0ece04f972216a21a29ac334abc33c8874b2b733 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 15:41:37 +0900 Subject: [PATCH 4/7] feature: StableAdamW optimizer --- pytorch_optimizer/optimizer/adamw.py | 135 +++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 pytorch_optimizer/optimizer/adamw.py diff --git a/pytorch_optimizer/optimizer/adamw.py b/pytorch_optimizer/optimizer/adamw.py new file mode 100644 index 000000000..02e4af5ca --- /dev/null +++ b/pytorch_optimizer/optimizer/adamw.py @@ -0,0 +1,135 @@ +import math + +import torch +from torch.optim.optimizer import Optimizer + +from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.optimizer.utils import debias_beta + + +class StableAdamW(Optimizer, BaseOptimizer): + r"""Stable and low-precision training for large-scale vision-language models. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. 
+ :param kahan_sum: bool. Enables Kahan summation for more accurate parameter updates when training in low precision + (float16 or bfloat16). + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. decoupled weight decay. + :param eps: float. term added to the denominator to improve numerical stability. + """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 1e-3, + betas: BETAS = (0.9, 0.99), + kahan_sum: bool = True, + weight_decay: float = 1e-2, + weight_decouple: bool = True, + eps: float = 1e-8, + ): + self.validate_learning_rate(lr) + self.validate_betas(betas) + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_non_negative(eps, 'eps') + + defaults: DEFAULTS = { + 'lr': lr, + 'betas': betas, + 'kahan_sum': kahan_sum, + 'weight_decay': weight_decay, + 'weight_decouple': weight_decouple, + 'eps': eps, + } + + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'StableAdamW' + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + state['kahan_comp'] = ( + torch.zeros_like(p) if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16} else None + ) + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + beta1, beta2 = group['betas'] + + beta1_comp: float = 1.0 - debias_beta(beta1, group['step']) + beta2_hat: float = debias_beta(beta2, group['step']) + + eps_p2: float = math.pow(group['eps'], 2) + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + state['kahan_comp'] = ( + torch.zeros_like(p) + if (group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16}) + else None + ) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg.lerp_(grad, weight=beta1_comp) + exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1.0 - beta2_hat) + + rms = grad.pow(2).div_(exp_avg_sq.clip(min=eps_p2)).mean().sqrt_() + + lr = group['lr'] / rms.clip(min=1.0) + + self.apply_weight_decay( + p, + p.grad, + lr=lr, + weight_decay=group['weight_decay'], + weight_decouple=group['weight_decouple'], + fixed_decay=False, + ) + + if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16}: + kahan_comp = state['kahan_comp'] + kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr) + + grad.copy_(p.detach()) + p.add_(kahan_comp) + + kahan_comp.add_(grad.sub_(p)) + else: + p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr) + + return loss From cfe887f08f1ad306d63300dfce0d6ef54d367501 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 15:41:43 +0900 Subject: [PATCH 5/7] chore: keyword --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f00a6429d..8bcff206b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,8 @@ keywords = [ "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", 
"PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH", - "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", - "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", + "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", + "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", ] classifiers = [ "License :: OSI Approved :: Apache Software License", From b7ddc4a0587fef438f809fb5de65dcfb3dc986f2 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 15:41:53 +0900 Subject: [PATCH 6/7] update: StableAdamW optimizer --- pytorch_optimizer/__init__.py | 2 ++ tests/constants.py | 3 +++ tests/test_load_modules.py | 2 +- tests/test_optimizers.py | 10 ++++++++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 340232fe6..3e1569a49 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -44,6 +44,7 @@ from pytorch_optimizer.optimizer.adamod import AdaMod from pytorch_optimizer.optimizer.adamp import AdamP from pytorch_optimizer.optimizer.adams import AdamS +from pytorch_optimizer.optimizer.adamw import StableAdamW from pytorch_optimizer.optimizer.adan import Adan from pytorch_optimizer.optimizer.adanorm import AdaNorm from pytorch_optimizer.optimizer.adapnm import AdaPNM @@ -201,6 +202,7 @@ FAdam, GrokFastAdamW, Kate, + StableAdamW, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/tests/constants.py b/tests/constants.py index 7daf9d6a8..55a2e0ba8 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -67,6 +67,7 @@ Shampoo, SignSGD, SophiaH, + StableAdamW, Tiger, Yogi, ) @@ -132,6 +133,7 @@ 'schedulefreeadamw', 'fadam', 'grokfastadamw', + 'stableadamw', ] VALID_LR_SCHEDULER_NAMES: List[str] = [ @@ -463,6 +465,7 @@ (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10), (Kate, {'lr': 5e-2}, 10), + (StableAdamW, {'lr': 1e0}, 5), ] ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10), diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index d13f0877b..bde4f45bc 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 69 + assert len(get_supported_optimizers()) == 70 def test_get_supported_lr_schedulers(): diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index e0abec4e2..97b195144 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -640,3 +640,13 @@ def test_grokfast_ema(environment): model.fc2.bias.grad = torch.randn(1) _ = gradfilter_ema(model, None) + + +def test_stableadamw_optimizer(environment): + _, model, _ = environment + + model.fc1.weight.data = torch.randn(2, 2, dtype=torch.float16) + + optimizer = load_optimizer('StableAdamW')(model.parameters()) + optimizer.reset() + optimizer.step() From 39a38f3397f201ae80473f6a7f46af1d10a13d22 Mon Sep 17 00:00:00 
2001
From: kozistr
Date: Sat, 6 Jul 2024 15:44:11 +0900
Subject: [PATCH 7/7] style: fix D401

---
 pytorch_optimizer/optimizer/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py
index 87a281de6..084f3576a 100644
--- a/pytorch_optimizer/optimizer/utils.py
+++ b/pytorch_optimizer/optimizer/utils.py
@@ -13,11 +13,14 @@
 def debias_beta(beta: float, step: int) -> float:
-    r"""Applies the Adam-style debias correction into beta.
+    r"""Apply the Adam-style debias correction into beta.

     Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`
+
+    :param beta: float. beta.
+    :param step: int. number of step.
     """
-    return (beta**step - beta) / (beta**step - 1.0)
+    return (beta ** step - beta) / (beta ** step - 1.0)  # fmt: skip


 def is_valid_parameters(parameters: PARAMETERS) -> bool:
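
The patches above cover everything needed to use the new optimizer end to end. What follows is a small, illustrative sketch, not part of any patch: it numerically checks the algebraic simplification documented in `debias_beta` (PATCH 3/7, reworded in PATCH 7/7) and runs a few full-precision steps of `StableAdamW` (PATCH 4/7) through the top-level export added in PATCH 6/7. The toy model, data, and hyperparameters are placeholders chosen only for the example.

import math

import torch
from torch import nn

from pytorch_optimizer import StableAdamW


def check_debias_beta(beta: float, step: int) -> None:
    # The docstring in utils.py says debias_beta is a simplified form of
    # `beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`; verify the identity numerically.
    original = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)
    simplified = (beta ** step - beta) / (beta ** step - 1.0)
    assert math.isclose(original, simplified)


check_debias_beta(beta=0.99, step=10)

# Toy model and data, purely for illustration.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
x, y = torch.randn(16, 4), torch.randn(16, 1)

optimizer = StableAdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.99),
    kahan_sum=True,        # only takes effect for float16/bfloat16 parameters
    weight_decay=1e-2,
    weight_decouple=True,  # decoupled (AdamW-style) weight decay
)

for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

For low-precision parameters, the test added in PATCH 6/7 (`test_stableadamw_optimizer`) is the reference: it casts a weight to float16 so that the Kahan compensation buffer is actually allocated and exercised.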