From 79066f94115eb7704ff7d0e340dd07ea88aae75c Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 17:58:45 +0900 Subject: [PATCH 01/14] docs: v3.3.1 changelog --- docs/changelogs/v3.3.1.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/changelogs/v3.3.1.md b/docs/changelogs/v3.3.1.md index f6980d63..eab3c666 100644 --- a/docs/changelogs/v3.3.1.md +++ b/docs/changelogs/v3.3.1.md @@ -6,3 +6,7 @@ * [Decoupled Momentum Optimization](https://arxiv.org/abs/2411.19870) * Implement `Muon` optimizer. (#302) * [MomentUm Orthogonalized by Newton-schulz](https://github.com/KellerJordan/Muon) +* Implement `ScheduleFreeRAdam` optimizer. (#304) +* Implement `LaProp` optimizer. (#304) + * [Separating Momentum and Adaptivity in Adam](https://arxiv.org/abs/2002.04839) +* Support `Cautious` variant to `LaProp`, `AdamP`. (#304). From 3a47e96703859e426548fa7abc3fb7c6d3098a63 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 17:58:56 +0900 Subject: [PATCH 02/14] feature: implement LaProp optimizer --- pytorch_optimizer/optimizer/laprop.py | 162 ++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 pytorch_optimizer/optimizer/laprop.py diff --git a/pytorch_optimizer/optimizer/laprop.py b/pytorch_optimizer/optimizer/laprop.py new file mode 100644 index 00000000..cd441242 --- /dev/null +++ b/pytorch_optimizer/optimizer/laprop.py @@ -0,0 +1,162 @@ +import math + +import torch + +from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS + + +class LaProp(BaseOptimizer): + r"""Separating Momentum and Adaptivity in Adam. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. + :param centered: bool. + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. + :param fixed_decay: bool. fix weight decay. + :param ams_bound: bool. whether to use the AMSBound variant. + :param cautious: bool. whether to use the Cautious variant. + :param eps: float. epsilon value. 
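+    :param steps_before_using_centered: int. number of steps to wait before applying the centered correction,
+        i.e. before the squared running mean of the gradient is subtracted from the second-moment estimate.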
+ """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 4e-4, + betas: BETAS = (0.9, 0.999), + centered: bool = False, + steps_before_using_centered: int = 10, + weight_decay: float = 0.0, + weight_decouple: bool = True, + fixed_decay: bool = False, + ams_bound: bool = False, + cautious: bool = False, + eps: float = 1e-15, + **kwargs, + ): + self.validate_learning_rate(lr) + self.validate_betas(betas) + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_non_negative(eps, 'eps') + + self.cautious = cautious + self.steps_before_using_centered: int = steps_before_using_centered + + defaults: DEFAULTS = { + 'lr': lr, + 'betas': betas, + 'centered': centered, + 'weight_decay': weight_decay, + 'weight_decouple': weight_decouple, + 'fixed_decay': fixed_decay, + 'ams_bound': ams_bound, + 'eps': eps, + } + + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'LaProp' + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['exp_avg_lr_1'] = 0.0 + state['exp_avg_lr_2'] = 0.0 + + if group['centered']: + state['exp_mean_avg_beta2'] = torch.zeros_like(p) + if group['ams_bound']: + state['max_exp_avg_sq'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + beta1, beta2 = group['betas'] + + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + state = self.state[p] + + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['exp_avg_lr_1'] = 0.0 + state['exp_avg_lr_2'] = 0.0 + + if group['centered']: + state['exp_mean_avg_beta2'] = torch.zeros_like(p) + if group['ams_bound']: + state['max_exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + state['exp_avg_lr_1'] = state['exp_avg_lr_1'] * beta1 + (1.0 - beta1) * group['lr'] + state['exp_avg_lr_2'] = state['exp_avg_lr_2'] * beta2 + (1.0 - beta2) + + bias_correction1: float = state['exp_avg_lr_1'] / group['lr'] if group['lr'] != 0.0 else 1.0 + step_size: float = 1.0 / bias_correction1 + + de_nom = exp_avg_sq + if group['centered']: + exp_mean_avg_beta2 = state['exp_mean_avg_beta2'] + exp_mean_avg_beta2.mul_(beta2).add_(grad, alpha=1.0 - beta2) + if group['step'] > self.steps_before_using_centered: + de_nom -= exp_mean_avg_beta2.pow(2) + + de_nom = self.apply_ams_bound( + ams_bound=group['ams_bound'], + exp_avg_sq=exp_avg_sq, + max_exp_avg_sq=state.get('max_exp_avg_sq', None), + eps=group['eps'], + ) + de_nom.div_(bias_correction2_sq) + + exp_avg.mul_(beta1).addcdiv_(grad, de_nom, value=(1.0 - beta1) * group['lr']) + + if self.cautious: + update = exp_avg.clone() + self.apply_cautious(update, grad) + else: + update = exp_avg + + p.add_(update, alpha=-step_size) + + self.apply_weight_decay( + p=p, + grad=p.grad, + lr=group['lr'], + weight_decay=group['weight_decay'], + weight_decouple=group['weight_decouple'], + fixed_decay=group['fixed_decay'], + ) 
+ + return loss From edf6247db7e13c8c89e13bdae534ce275e412c6e Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 17:59:04 +0900 Subject: [PATCH 03/14] docs: README --- README.md | 11 ++++++++--- docs/index.md | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 587f389e..d565c09f 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,13 @@ | Status | [![PyPi download](https://static.pepy.tech/badge/pytorch-optimizer)](https://pepy.tech/project/pytorch-optimizer) [![PyPi month download](https://static.pepy.tech/badge/pytorch-optimizer/month)](https://pepy.tech/project/pytorch-optimizer) | | License | [![apache](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) | -**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. -I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +## The reasons why you use `pytorch-optimizer`. + +1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion` +3. Easy to use, clean, and tested codes +4. Active maintenance +5. Somewhat a bit more optimized compared to the original implementation Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -187,6 +191,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) | | MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) | | Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | | [cite](https://github.com/KellerJordan/Muon) | +| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) | ## Supported LR Scheduler diff --git a/docs/index.md b/docs/index.md index 587f389e..d565c09f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,9 +8,13 @@ | Status | [![PyPi download](https://static.pepy.tech/badge/pytorch-optimizer)](https://pepy.tech/project/pytorch-optimizer) [![PyPi month download](https://static.pepy.tech/badge/pytorch-optimizer/month)](https://pepy.tech/project/pytorch-optimizer) | | License | [![apache](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) | -**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. -I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +## The reasons why you use `pytorch-optimizer`. + +1. Wide range of supported optimizers. 
Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion` +3. Easy to use, clean, and tested codes +4. Active maintenance +5. Somewhat a bit more optimized compared to the original implementation Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -187,6 +191,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) | | MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) | | Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | | [cite](https://github.com/KellerJordan/Muon) | +| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) | ## Supported LR Scheduler From 5060bd3c5d9242f9d8a4800202e4a6e8f4d21d3b Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 17:59:12 +0900 Subject: [PATCH 04/14] docs: LaProp optimizer --- docs/optimizer.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/optimizer.md b/docs/optimizer.md index 8b5289be..887f5e48 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -204,6 +204,10 @@ :docstring: :members: +::: pytorch_optimizer.LaProp + :docstring: + :members: + ::: pytorch_optimizer.LARS :docstring: :members: @@ -296,6 +300,10 @@ :docstring: :members: +::: pytorch_optimizer.ScheduleFreeRAdam + :docstring: + :members: + ::: pytorch_optimizer.StableAdamW :docstring: :members: From 8b6f531075aaeb51a45ef63aab88e729d3773c4e Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 17:59:20 +0900 Subject: [PATCH 05/14] fix: eps2 validation --- pytorch_optimizer/optimizer/adalite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_optimizer/optimizer/adalite.py b/pytorch_optimizer/optimizer/adalite.py index cf43aa82..3249826c 100644 --- a/pytorch_optimizer/optimizer/adalite.py +++ b/pytorch_optimizer/optimizer/adalite.py @@ -41,7 +41,7 @@ def __init__( self.validate_betas(betas) self.validate_non_negative(weight_decay, 'weight_decay') self.validate_non_negative(eps1, 'eps1') - self.validate_non_negative(eps2, 'eps1') + self.validate_non_negative(eps2, 'eps2') defaults: DEFAULTS = { 'lr': lr, From 52434144892db429dca4fea2ec77b87347018793 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 17:59:26 +0900 Subject: [PATCH 06/14] chore: keywords --- pyproject.toml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aef3b805..7e3eea6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,13 +14,13 @@ keywords = [ "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT", "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", - "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", 
"Lamb", "LARS", - "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", - "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", - "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", - "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", - "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", - "QGaLore", + "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", + "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", + "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", + "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", + "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", + "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", + "bitsandbytes", "WSD", "QGaLore", ] classifiers = [ "License :: OSI Approved :: Apache Software License", From a979d4985ebbfce75686ac1311aa765406cb8df7 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 17:59:37 +0900 Subject: [PATCH 07/14] feature: support Cautious variant --- pytorch_optimizer/optimizer/adamp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_optimizer/optimizer/adamp.py b/pytorch_optimizer/optimizer/adamp.py index ae9019b0..0971ffeb 100644 --- a/pytorch_optimizer/optimizer/adamp.py +++ b/pytorch_optimizer/optimizer/adamp.py @@ -22,6 +22,7 @@ class AdamP(BaseOptimizer): :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters. :param use_gc: bool. use gradient centralization. + :param cautious: bool. whether to use the Cautious variant. :param nesterov: bool. enables Nesterov momentum. :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred. :param adanorm: bool. whether to use the AdaNorm variant. @@ -40,6 +41,7 @@ def __init__( delta: float = 0.1, wd_ratio: float = 0.1, use_gc: bool = False, + cautious: bool = False, nesterov: bool = False, r: float = 0.95, adanorm: bool = False, @@ -54,6 +56,7 @@ def __init__( self.validate_non_negative(eps, 'eps') self.use_gc = use_gc + self.cautious = cautious defaults: DEFAULTS = { 'lr': lr, @@ -170,6 +173,9 @@ def step(self, closure: CLOSURE = None) -> LOSS: bias_correction1=bias_correction1, ) + if self.cautious: + self.apply_cautious(perturb, grad) + p.add_(perturb, alpha=-step_size) return loss From 0b6501b297d1fb6b108d429509879100606f11b8 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 17:59:52 +0900 Subject: [PATCH 08/14] feature: implement ScheduleFreeRAdam optimizer --- pytorch_optimizer/optimizer/schedulefree.py | 166 ++++++++++++++++++++ 1 file changed, 166 insertions(+) diff --git a/pytorch_optimizer/optimizer/schedulefree.py b/pytorch_optimizer/optimizer/schedulefree.py index 1d31dd5a..bd8c2472 100644 --- a/pytorch_optimizer/optimizer/schedulefree.py +++ b/pytorch_optimizer/optimizer/schedulefree.py @@ -316,3 +316,169 @@ def step(self, closure: CLOSURE = None) -> LOSS: z.sub_(grad, alpha=lr) return loss + + +class ScheduleFreeRAdam(BaseOptimizer): + r"""Schedule-Free RAdam. 
+ + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. + :param fixed_decay: bool. fix weight decay. + :param degenerated_to_sgd: float. degenerated to SGD. + :param r: float. use polynomial weighting in the average with power r. + :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power. + set to 0 for no weighting. + :param warmup_steps: int. enables a linear learning rate warmup. + :param ams_bound: bool. whether to use the AMSBound variant. + :param eps: float. term added to the denominator to improve numerical stability. + """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 2.5e-3, + betas: BETAS = (0.9, 0.999), + weight_decay: float = 0.0, + weight_decouple: bool = True, + fixed_decay: bool = False, + degenerated_to_sgd: bool = False, + r: float = 0.0, + weight_lr_power: float = 2.0, + eps: float = 1e-8, + **kwargs, + ): + self.validate_learning_rate(lr) + self.validate_betas(betas) + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_non_negative(eps, 'eps') + + defaults: DEFAULTS = { + 'lr': lr, + 'betas': betas, + 'weight_decay': weight_decay, + 'weight_decouple': weight_decouple, + 'fixed_decay': fixed_decay, + 'degenerated_to_sgd': degenerated_to_sgd, + 'r': r, + 'weight_lr_power': weight_lr_power, + 'eps': eps, + 'train_mode': True, + 'weight_sum': 0.0, + 'lr_max': -1.0, + 'use_palm': kwargs.get('use_palm', False), + } + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'ScheduleFreeRAdam' + + def eval(self): + for group in self.param_groups: + beta1, _ = group['betas'] + if group['train_mode']: + for p in group['params']: + state = self.state[p] + if 'z' in state: + p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1) + group['train_mode'] = False + + def train(self): + for group in self.param_groups: + beta1, _ = group['betas'] + if not group['train_mode']: + for p in group['params']: + state = self.state[p] + if 'z' in state: + p.data.lerp_(end=state['z'], weight=1.0 - beta1) + group['train_mode'] = True + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + state['z'] = p.clone() + state['exp_avg_sq'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + beta1, beta2 = group['betas'] + + bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + + lr, n_sma = self.get_rectify_step_size( + is_rectify=True, + step=group['step'], + lr=group['lr'], + beta2=beta2, + n_sma_threshold=4, + degenerated_to_sgd=group['degenerated_to_sgd'], + ) + + lr_max = group['lr_max'] = max(lr, group['lr_max']) + + weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power']) + weight_sum = group['weight_sum'] = group['weight_sum'] + weight + + checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0 + + adaptive_y_lr: float = lr * (beta1 * (1.0 - checkpoint) - 1.0) + + if 
group['use_palm']: + beta2: float = 1.0 - group['step'] ** -0.8 + debias: float = (1.0 - beta2) / (1.0 - beta2 ** group['step']) + else: + debias: float = beta2 + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + state = self.state[p] + + if len(state) == 0: + state['z'] = p.clone() + state['exp_avg_sq'] = torch.zeros_like(p) + + z, exp_avg_sq = state['z'], state['exp_avg_sq'] + exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias) + + if n_sma > 4.0: + de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps']) + grad.div_(de_nom) + + self.apply_weight_decay( + p=p, + grad=grad, + lr=lr, + weight_decay=group['weight_decay'], + weight_decouple=group['weight_decouple'], + fixed_decay=group['fixed_decay'], + ) + + p.lerp_(z, weight=checkpoint) + p.add_(grad, alpha=adaptive_y_lr) + + z.sub_(grad, alpha=lr) + + return loss From 50987e7b7be2288a46654c65032a0f9dc1069c32 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 18:00:01 +0900 Subject: [PATCH 09/14] update: test cases --- tests/constants.py | 20 ++++++++++++++++---- tests/test_load_modules.py | 2 +- tests/test_optimizers.py | 26 ++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/tests/constants.py b/tests/constants.py index 06814951..96c495ac 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -57,6 +57,7 @@ GrokFastAdamW, Kate, Lamb, + LaProp, Lion, Muon, Nero, @@ -69,6 +70,7 @@ Ranger21, ScalableShampoo, ScheduleFreeAdamW, + ScheduleFreeRAdam, ScheduleFreeSGD, Shampoo, SignSGD, @@ -146,6 +148,7 @@ 'ademamix', 'soap', 'muon', + 'laprop', ] VALID_LR_SCHEDULER_NAMES: List[str] = [ @@ -204,7 +207,7 @@ (MADGRAD, {'lr': 5e-1, 'weight_decay': 1e-3, 'eps': 0.0}, 10), (MADGRAD, {'lr': 1e-1, 'weight_decay': 1e-3, 'momentum': 0.0}, 10), (MADGRAD, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': True}, 10), - (RAdam, {'lr': 5e-0, 'weight_decay': 1e-3}, 10), + (RAdam, {'lr': 5e0, 'weight_decay': 1e-3}, 10), (RAdam, {'lr': 5e-1, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 5), (SGDP, {'lr': 5e-1, 'weight_decay': 1e-4}, 10), (SGDP, {'lr': 5e-1, 'weight_decay': 1e-4, 'nesterov': True}, 10), @@ -377,7 +380,6 @@ (AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20), (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100), (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120), - (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'cautious': True}, 70), (AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40), (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10), (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10), @@ -386,7 +388,6 @@ (Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 5), (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 5), (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 10), - (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'cautious': True}, 5), (AliG, {'max_lr': 5e-1, 'momentum': 0.9}, 5), (AliG, {'max_lr': 5e-1, 'momentum': 0.9, 'adjusted_momentum': True}, 5), (SM3, {'lr': 5e-1, 'momentum': 0.9, 'beta': 0.9}, 5), @@ -482,6 +483,8 @@ (ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (ScheduleFreeAdamW, {'lr': 1e-2, 'weight_decay': 1e-3, 'use_palm': True}, 5), + (ScheduleFreeRAdam, {'lr': 1e0, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 5), + (ScheduleFreeRAdam, {'lr': 1e0, 'weight_decay': 
1e-3, 'use_palm': True, 'degenerated_to_sgd': True}, 5), (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (GrokFastAdamW, {'lr': 5e0, 'weight_decay': 1e-3, 'grokfast_after_step': 1}, 5), (Kate, {'lr': 5e-2}, 10), @@ -489,7 +492,6 @@ (AdamG, {'lr': 1e0}, 20), (AdEMAMix, {'lr': 1e0}, 3), (AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 3), - (AdEMAMix, {'lr': 1e0, 'cautious': True}, 2), ( SOAP, {'lr': 1e0, 'shampoo_beta': 0.95, 'precondition_frequency': 1, 'merge_dims': False, 'precondition_1d': True}, @@ -499,6 +501,9 @@ (FTRL, {'lr': 1e0, 'beta': 0.0, 'lambda_1': 0.0, 'lambda_2': 0.0}, 5), (Muon, {'lr': 1e0, 'ns_steps': 6, 'adam_lr': 1e0, 'adamw_wd': 1e-2}, 5), (Muon, {'lr': 1e0, 'ns_steps': 6, 'adam_lr': 1e0, 'adamw_wd': 1e-2, 'nesterov': False}, 5), + (LaProp, {'lr': 1e0, 'weight_decay': 1e-3}, 5), + (LaProp, {'lr': 1e0, 'centered': True, 'weight_decay': 1e-3}, 11), + (LaProp, {'lr': 1e0, 'ams_bound': True, 'weight_decay': 1e-3}, 5), ] ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10), @@ -540,3 +545,10 @@ (AdaHessian, {'lr': 5e0, 'weight_decay': 1e-3, 'adam_debias': True}, 5), (Aida, {'lr': 1e1, 'weight_decay': 1e-3, 'rectify': True, 'adam_debias': True}, 10), ] +COPT_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ + (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'cautious': True}, 70), + (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'cautious': True}, 5), + (AdEMAMix, {'lr': 1e0, 'cautious': True}, 2), + (LaProp, {'lr': 1e0, 'cautious': True}, 2), + (AdamP, {'lr': 1e0, 'cautious': True}, 2), +] diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index 06fbe93c..ccca7ef5 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 80 + assert len(get_supported_optimizers()) == 82 assert len(get_supported_optimizers('adam*')) == 7 assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 9 diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index effcd2d2..08ce607c 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -26,6 +26,7 @@ ADAMD_SUPPORTED_OPTIMIZERS, ADANORM_SUPPORTED_OPTIMIZERS, ADAPTIVE_FLAGS, + COPT_SUPPORTED_OPTIMIZERS, DECOUPLE_FLAGS, OPTIMIZERS, PULLBACK_MOMENTUM, @@ -335,6 +336,31 @@ def test_adamd_optimizers(optimizer_config, environment): assert tensor_to_numpy(init_loss) > 2.0 * tensor_to_numpy(loss) +@pytest.mark.parametrize('optimizer_config', COPT_SUPPORTED_OPTIMIZERS, ids=ids) +def test_copt_optimizers(optimizer_config, environment): + (x_data, y_data), model, loss_fn = environment + + optimizer_class, config, num_iterations = optimizer_config + + optimizer = optimizer_class(model.parameters(), **config) + + init_loss, loss = np.inf, np.inf + for _ in range(num_iterations): + optimizer.zero_grad() + + y_pred = model(x_data) + loss = loss_fn(y_pred, y_data) + + if init_loss == np.inf: + init_loss = loss + + loss.backward() + + optimizer.step() + + assert tensor_to_numpy(init_loss) > 1.5 * tensor_to_numpy(loss) + + @pytest.mark.parametrize('reduction', ['mean', 'sum']) def test_pc_grad_optimizers(reduction): x_data, y_data = make_dataset() From b573013e9c9683f1705d6d28bb86db34b98fbc8a Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 18:00:09 +0900 Subject: [PATCH 10/14] update: optimizers --- 
pytorch_optimizer/__init__.py | 2 ++ pytorch_optimizer/optimizer/__init__.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 9f6b3c30..8af7e027 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -107,6 +107,7 @@ GrokFastAdamW, Kate, Lamb, + LaProp, Lion, Lookahead, Muon, @@ -123,6 +124,7 @@ SafeFP16Optimizer, ScalableShampoo, ScheduleFreeAdamW, + ScheduleFreeRAdam, ScheduleFreeSGD, Shampoo, SignSGD, diff --git a/pytorch_optimizer/optimizer/__init__.py b/pytorch_optimizer/optimizer/__init__.py index 4125a058..23397307 100644 --- a/pytorch_optimizer/optimizer/__init__.py +++ b/pytorch_optimizer/optimizer/__init__.py @@ -50,6 +50,7 @@ from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW from pytorch_optimizer.optimizer.kate import Kate from pytorch_optimizer.optimizer.lamb import Lamb +from pytorch_optimizer.optimizer.laprop import LaProp from pytorch_optimizer.optimizer.lars import LARS from pytorch_optimizer.optimizer.lion import Lion from pytorch_optimizer.optimizer.lomo import LOMO, AdaLOMO @@ -71,7 +72,7 @@ from pytorch_optimizer.optimizer.ranger21 import Ranger21 from pytorch_optimizer.optimizer.rotograd import RotoGrad from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM -from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeSGD +from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD from pytorch_optimizer.optimizer.sgdp import SGDP from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo @@ -275,6 +276,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER: FTRL, DeMo, Muon, + ScheduleFreeRAdam, + LaProp, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} From a6287497f81707c6f6f4d4e93231b2945e7007e5 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 18:05:19 +0900 Subject: [PATCH 11/14] feature: support cautious variant --- pytorch_optimizer/optimizer/adopt.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pytorch_optimizer/optimizer/adopt.py b/pytorch_optimizer/optimizer/adopt.py index 9e95bb22..71d23ea8 100644 --- a/pytorch_optimizer/optimizer/adopt.py +++ b/pytorch_optimizer/optimizer/adopt.py @@ -17,6 +17,7 @@ class ADOPT(BaseOptimizer): :param weight_decay: float. weight decay (L2 penalty). :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. :param fixed_decay: bool. fix weight decay. + :param cautious: bool. whether to use the Cautious variant. :param eps: float. term added to the denominator to improve numerical stability. 
""" @@ -29,6 +30,7 @@ def __init__( weight_decay: float = 0.0, weight_decouple: bool = False, fixed_decay: bool = False, + cautious: bool = False, eps: float = 1e-6, **kwargs, ): @@ -38,6 +40,7 @@ def __init__( self.validate_non_negative(eps, 'eps') self.clip_lambda = clip_lambda + self.cautious = cautious defaults: DEFAULTS = { 'lr': lr, @@ -118,6 +121,12 @@ def step(self, closure: CLOSURE = None) -> LOSS: exp_avg.lerp_(normed_grad, weight=1.0 - beta1) - p.add_(exp_avg, alpha=-group['lr']) + if self.cautious: + update = exp_avg.clone() + self.apply_cautious(update, normed_grad) + else: + update = exp_avg + + p.add_(update, alpha=-group['lr']) return loss From e89c146c8fdbfa51d5bfc27ca6c99ae33e924fc6 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 18:05:26 +0900 Subject: [PATCH 12/14] update: test cases --- tests/constants.py | 1 + tests/test_optimizers.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/constants.py b/tests/constants.py index 96c495ac..ec397736 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -551,4 +551,5 @@ (AdEMAMix, {'lr': 1e0, 'cautious': True}, 2), (LaProp, {'lr': 1e0, 'cautious': True}, 2), (AdamP, {'lr': 1e0, 'cautious': True}, 2), + (ADOPT, {'lr': 1e1, 'cautious': True}, 3), ] diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 08ce607c..80e7b812 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -298,7 +298,7 @@ def test_adanorm_optimizer(optimizer_config, environment): @pytest.mark.parametrize('optimizer_config', ADANORM_SUPPORTED_OPTIMIZERS, ids=ids) -def test_adanorm_condition(optimizer_config): +def test_adanorm_variant(optimizer_config): param = simple_parameter(True) param.grad = torch.ones(1, 1) @@ -312,7 +312,7 @@ def test_adanorm_condition(optimizer_config): @pytest.mark.parametrize('optimizer_config', ADAMD_SUPPORTED_OPTIMIZERS, ids=ids) -def test_adamd_optimizers(optimizer_config, environment): +def test_adamd_variant(optimizer_config, environment): (x_data, y_data), model, loss_fn = environment optimizer_class, config, num_iterations = optimizer_config @@ -337,7 +337,7 @@ def test_adamd_optimizers(optimizer_config, environment): @pytest.mark.parametrize('optimizer_config', COPT_SUPPORTED_OPTIMIZERS, ids=ids) -def test_copt_optimizers(optimizer_config, environment): +def test_cautious_variant(optimizer_config, environment): (x_data, y_data), model, loss_fn = environment optimizer_class, config, num_iterations = optimizer_config From e1ab49387e5dc18574f7c39888b6295754ca5483 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 18:05:34 +0900 Subject: [PATCH 13/14] docs: v3.3.1 changelog --- docs/changelogs/v3.3.1.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelogs/v3.3.1.md b/docs/changelogs/v3.3.1.md index eab3c666..2a366cdf 100644 --- a/docs/changelogs/v3.3.1.md +++ b/docs/changelogs/v3.3.1.md @@ -9,4 +9,4 @@ * Implement `ScheduleFreeRAdam` optimizer. (#304) * Implement `LaProp` optimizer. (#304) * [Separating Momentum and Adaptivity in Adam](https://arxiv.org/abs/2002.04839) -* Support `Cautious` variant to `LaProp`, `AdamP`. (#304). +* Support `Cautious` variant to `LaProp`, `AdamP`, `Adopt` optimizers. (#304). 
From 53264836564487e797eb11683c9fcb5af42facde Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 4 Dec 2024 18:11:24 +0900 Subject: [PATCH 14/14] update: test_schedule_free_methods --- tests/test_optimizers.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 80e7b812..2d24859c 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -639,18 +639,12 @@ def test_dynamic_scaler(): scaler.update_scale(overflow=False) -def test_schedule_free_train_mode(): - param = simple_parameter(True) - - opt = load_optimizer('ScheduleFreeAdamW')([param]) - opt.reset() - opt.eval() - opt.train() - - opt = load_optimizer('ScheduleFreeSGD')([param]) - opt.reset() - opt.eval() - opt.train() +@pytest.mark.parametrize('optimizer_name', ['ScheduleFreeAdamW', 'ScheduleFreeSGD', 'ScheduleFreeRAdam']) +def test_schedule_free_methods(optimizer_name): + optimizer = load_optimizer(optimizer_name)([simple_parameter(True)]) + optimizer.reset() + optimizer.eval() + optimizer.train() @pytest.mark.parametrize('filter_type', ['mean', 'sum'])
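For reference, a minimal usage sketch of the two optimizers introduced in this series; the model and objective below are placeholders, not part of the patch:

```python
import torch

from pytorch_optimizer import LaProp, ScheduleFreeRAdam

model = torch.nn.Linear(10, 1)  # placeholder model

# LaProp with the newly added Cautious variant enabled
optimizer = LaProp(model.parameters(), lr=4e-4, cautious=True)

# ScheduleFreeRAdam, like the other schedule-free optimizers in this library,
# is switched between train() and eval() modes explicitly
optimizer = ScheduleFreeRAdam(model.parameters(), lr=2.5e-3)
optimizer.train()
for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()  # placeholder objective
    loss.backward()
    optimizer.step()
optimizer.eval()  # move weights to the averaged point before validation or saving
```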