[Feature] Implement FTRL optimizer #291

Merged · 15 commits · Nov 23, 2024
167 changes: 84 additions & 83 deletions README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docs/changelogs/v3.3.0.md
@@ -6,3 +6,5 @@
* you can use this feature by setting `use_palm` to `True`.
* Implement `ADOPT` optimizer. (#289, #290)
* [Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853)
* Implement `FTRL` optimizer. (#291)
* [Follow The Regularized Leader](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
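
For context, the per-coordinate FTRL-Proximal update described in the linked paper, which the implementation in `pytorch_optimizer/optimizer/ftrl.py` below follows, is roughly:

$$
n_t = n_{t-1} + g_t^2, \qquad
\sigma_t = \frac{\sqrt{n_t} - \sqrt{n_{t-1}}}{\alpha}, \qquad
z_t = z_{t-1} + g_t - \sigma_t w_t
$$

$$
w_{t+1} =
\begin{cases}
0, & |z_t| \le \lambda_1 \\
-\dfrac{z_t - \operatorname{sign}(z_t)\,\lambda_1}{(\beta + \sqrt{n_t})/\alpha + \lambda_2}, & \text{otherwise}
\end{cases}
$$

Here $\alpha$ is `lr`; the square roots correspond to the default `lr_power = -0.5`, and other values of `lr_power` use the corresponding power of $n$ instead of the square root.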
167 changes: 84 additions & 83 deletions docs/index.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -164,6 +164,10 @@
:docstring:
:members:

::: pytorch_optimizer.FTRL
:docstring:
:members:

::: pytorch_optimizer.GaLoreProjector
:docstring:
:members:
60 changes: 30 additions & 30 deletions poetry.lock

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
@@ -14,9 +14,9 @@ keywords = [
"AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
"AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
"Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion",
"DiffGrad", "FAdam", "Fromage", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO",
"Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM",
"RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo",
"DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion",
"LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam",
"QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo",
"ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC",
"WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered",
"Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -70,6 +70,7 @@
from pytorch_optimizer.optimizer.fadam import FAdam
from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
from pytorch_optimizer.optimizer.fromage import Fromage
from pytorch_optimizer.optimizer.ftrl import FTRL
from pytorch_optimizer.optimizer.galore import GaLore, GaLoreProjector
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.gravity import Gravity
@@ -217,6 +218,7 @@
    AdEMAMix,
    SOAP,
    ADOPT,
    FTRL,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

5 changes: 5 additions & 0 deletions pytorch_optimizer/base/optimizer.py
@@ -271,6 +271,11 @@ def validate_non_negative(x: Optional[float], name: str) -> None:
        if x is not None and x < 0.0:
            raise ValueError(f'[-] {name} must be non-negative')

    @staticmethod
    def validate_non_positive(x: Optional[float], name: str) -> None:
        if x is not None and x > 0.0:
            raise ValueError(f'[-] {name} must be non-positive')

    @staticmethod
    def validate_positive(x: Union[float, int], name: str) -> None:
        if x <= 0:
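A quick sketch of what the new `validate_non_positive` helper accepts and rejects (values chosen only for illustration):

```python
from pytorch_optimizer.base.optimizer import BaseOptimizer

BaseOptimizer.validate_non_positive(-0.5, 'lr_power')  # ok: negative values pass
BaseOptimizer.validate_non_positive(0.0, 'lr_power')   # ok: zero passes
BaseOptimizer.validate_non_positive(None, 'lr_power')  # ok: None is skipped
BaseOptimizer.validate_non_positive(0.5, 'lr_power')   # raises ValueError('[-] lr_power must be non-positive')
```

`FTRL.__init__` below uses it to require `lr_power <= 0`.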
88 changes: 88 additions & 0 deletions pytorch_optimizer/optimizer/ftrl.py
@@ -0,0 +1,88 @@
import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class FTRL(BaseOptimizer):
    r"""Follow The Regularized Leader.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param lr_power: float. controls how the learning rate decreases during training. use zero for a fixed learning
        rate.
    :param beta: float. beta value in the paper, added to the accumulated squared gradient in the per-coordinate
        step size.
    :param lambda_1: float. L1 regularization parameter.
    :param lambda_2: float. L2 regularization parameter.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        lr_power: float = -0.5,
        beta: float = 0.0,
        lambda_1: float = 0.0,
        lambda_2: float = 0.0,
        **kwargs
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(beta, 'beta')
        self.validate_non_positive(lr_power, 'lr_power')
        self.validate_non_negative(lambda_1, 'lambda_1')
        self.validate_non_negative(lambda_2, 'lambda_2')

        defaults: DEFAULTS = {'lr': lr, 'lr_power': lr_power, 'beta': beta, 'lambda_1': lambda_1, 'lambda_2': lambda_2}
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'FTRL'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['z'] = torch.zeros_like(p)
                state['n'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['z'] = torch.zeros_like(p)
                    state['n'] = torch.zeros_like(p)

                z, n = state['z'], state['n']

                grad_p2 = grad.pow(2)

                # sigma = ((n + g^2) ** -lr_power - n ** -lr_power) / lr;
                # with the default lr_power = -0.5 this is (sqrt(n_new) - sqrt(n_old)) / lr
                sigma = (n + grad_p2).pow_(-group['lr_power']).sub_(n.pow(-group['lr_power'])).div_(group['lr'])

                # accumulators: z += g - sigma * w, n += g^2
                z.add_(grad).sub_(sigma.mul(p))
                n.add_(grad_p2)

                # closed-form weight: w = -(z - sign(z) * lambda_1) / ((beta + sqrt(n)) / lr + lambda_2)
                update = z.sign().mul_(group['lambda_1']).sub_(z)
                update.div_((group['beta'] + n.sqrt()).div_(group['lr']).add_(group['lambda_2']))

                p.copy_(update)
                # L1 thresholding: coordinates with |z| below lambda_1 are set exactly to zero
                p.masked_fill_(z.abs() < group['lambda_1'], 0.0)

        return loss
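
For reviewers who want to try it, a minimal usage sketch (the toy model, data, and hyper-parameters below are illustrative, not part of the PR):

```python
import torch
import torch.nn.functional as F

from pytorch_optimizer import FTRL

model = torch.nn.Linear(10, 1)
optimizer = FTRL(model.parameters(), lr=1e-1, lr_power=-0.5, beta=0.0, lambda_1=1e-3, lambda_2=1e-3)

x, y = torch.randn(256, 10), torch.randn(256, 1)
for _ in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

# with lambda_1 > 0, coordinates whose accumulated |z| stays below lambda_1 are set exactly to zero,
# so FTRL can produce sparse weights
print('zero weights:', (model.weight == 0).float().mean().item())
```

The optimizer is also registered in `OPTIMIZER_LIST`, which is why `get_supported_optimizers()` now reports 78 entries (see `tests/test_load_modules.py` below).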
4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/lomo.py
@@ -83,7 +83,7 @@ def func(x: Any) -> Any:
if not p.requires_grad or p.grad is None:
continue

if self.loss_scaler and self.loss_scaler.has_overflow_serial or has_overflow(p.grad):
if (self.loss_scaler and self.loss_scaler.has_overflow_serial) or has_overflow(p.grad):
p.grad = None
self.loss_scaler.has_overflow_serial = True
break
@@ -119,7 +119,7 @@ def func(x: torch.Tensor) -> torch.Tensor:

all_reduce(p.grad, op=ReduceOp.AVG, async_op=False)

if self.loss_scaler and self.loss_scaler.has_overflow_serial or has_overflow(p.grad):
if (self.loss_scaler and self.loss_scaler.has_overflow_serial) or has_overflow(p.grad):
p.grad = None
self.loss_scaler.has_overflow_serial = True
break
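A note on the two `lomo.py` edits: since `and` binds tighter than `or` in Python, the added parentheses make the existing grouping explicit rather than changing evaluation. A quick check of the equivalence:

```python
# `A and B or C` already parses as `(A and B) or C`
for A in (False, True):
    for B in (False, True):
        for C in (False, True):
            assert (A and B or C) == ((A and B) or C)
```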
6 changes: 3 additions & 3 deletions requirements-dev.txt
@@ -22,10 +22,10 @@ platformdirs==4.3.6 ; python_version >= "3.8"
pluggy==1.5.0 ; python_version >= "3.8"
pytest-cov==5.0.0 ; python_version >= "3.8"
pytest==8.3.3 ; python_version >= "3.8"
ruff==0.7.3 ; python_version >= "3.8"
setuptools==75.3.0 ; python_version >= "3.12"
ruff==0.8.0 ; python_version >= "3.8"
setuptools==75.6.0 ; python_version >= "3.12"
sympy==1.12.1 ; python_version == "3.8"
sympy==1.13.1 ; python_version >= "3.9"
tomli==2.0.2 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
tomli==2.1.0 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
torch==2.5.1+cpu ; python_version >= "3.8"
typing-extensions==4.12.2 ; python_version >= "3.8"
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,7 +8,7 @@ mpmath==1.3.0 ; python_version >= "3.9" or python_version == "3.8"
networkx==3.1 ; python_version >= "3.8"
numpy==1.24.4 ; python_version < "3.9" and python_version >= "3.8"
numpy==2.0.2 ; python_version >= "3.9"
setuptools==75.3.0 ; python_version >= "3.12"
setuptools==75.6.0 ; python_version >= "3.12"
sympy==1.12.1 ; python_version == "3.8"
sympy==1.13.1 ; python_version >= "3.9"
torch==2.5.1+cpu ; python_version >= "3.8"
2 changes: 2 additions & 0 deletions tests/constants.py
@@ -4,6 +4,7 @@
    ADOPT,
    ASGD,
    CAME,
    FTRL,
    LARS,
    MADGRAD,
    MSVAG,
@@ -485,6 +486,7 @@
        3,
    ),
    (ADOPT, {'lr': 1e0}, 5),
    (FTRL, {'lr': 1e0, 'beta': 0.0, 'lambda_1': 0.0, 'lambda_2': 0.0}, 5),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
5 changes: 5 additions & 0 deletions tests/test_base.py
@@ -36,6 +36,11 @@ def test_validate_range(range_type):
        BaseOptimizer.validate_range(-1.0, 'x', 0.0, 1.0, range_type=range_type)


def test_non_positive():
    with pytest.raises(ValueError):
        BaseOptimizer.validate_non_positive(1.0, 'asdf')


def test_mod():
    with pytest.raises(ValueError):
        BaseOptimizer.validate_mod(10, 3)
2 changes: 2 additions & 0 deletions tests/test_general_optimizer_parameters.py
@@ -51,6 +51,7 @@ def test_epsilon(optimizer_name):
        'adalite',
        'bsam',
        'adalomo',
        'ftrl',
    ):
        pytest.skip(f'skip {optimizer_name} optimizer')

@@ -78,6 +79,7 @@ def test_weight_decay(optimizer_name):
        'adashift',
        'amos',
        'lomo',
        'ftrl',
    ):
        pytest.skip(f'skip {optimizer_name} optimizer')

2 changes: 1 addition & 1 deletion tests/test_load_modules.py
@@ -38,7 -38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


def test_get_supported_optimizers():
    assert len(get_supported_optimizers()) == 77
    assert len(get_supported_optimizers()) == 78
    assert len(get_supported_optimizers('adam*')) == 7
    assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 9

Expand Down