Merge pull request #263 from kozistr/feature/trac-optimizer
[Feature] Implement TRAC optimizer
kozistr authored Aug 7, 2024
2 parents d00136f + e94290f commit 1054960
Showing 15 changed files with 483 additions and 169 deletions.
157 changes: 79 additions & 78 deletions README.md

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions docs/changelogs/v3.1.1.md
@@ -0,0 +1,11 @@
## Change Log

### Feature

* Implement `TRAC` optimizer. (#263)
* [Fast TRAC: A Parameter-Free Optimizer for Lifelong Reinforcement Learning](https://arxiv.org/abs/2405.16642)
* Support `AdamW` optimizer via `create_optimizer()`. (#263)

### Bug

* Fix to handle the optimizers that only take the `model` instead of the parameters in `create_optimizer()`. (#263)
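
A minimal sketch of the `create_optimizer()` changes listed above (illustrative only, not part of this commit's diff); it assumes the factory keeps its existing `(model, optimizer_name, ...)` signature, and the `adalomo` call exercises the model-based branch fixed here:

```python
from torch import nn

from pytorch_optimizer import create_optimizer

model = nn.Linear(8, 2)

# AdamW is now registered with the factory, so it can be created by name.
adamw = create_optimizer(model, 'adamw', lr=1e-3, weight_decay=1e-2)

# LOMO-family optimizers are constructed from the model itself rather than
# model.parameters(), which is the case this commit fixes.
adalomo = create_optimizer(model, 'adalomo', lr=1e-3)
```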
157 changes: 79 additions & 78 deletions docs/index.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -328,6 +328,10 @@
:docstring:
:members:

::: pytorch_optimizer.TRAC
:docstring:
:members:

::: pytorch_optimizer.WSAM
:docstring:
:members:
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -17,7 +17,7 @@ keywords = [
"GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG",
"Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21",
"RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
"SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
"SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
"bitsandbytes", "WSD", "QGaLore",
]
5 changes: 5 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -4,6 +4,7 @@

import torch.cuda
from torch import nn
from torch.optim import AdamW

from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER
from pytorch_optimizer.loss.bi_tempered import BinaryBiTemperedLogisticLoss, BiTemperedLogisticLoss
@@ -115,6 +116,7 @@
from pytorch_optimizer.optimizer.srmm import SRMM
from pytorch_optimizer.optimizer.swats import SWATS
from pytorch_optimizer.optimizer.tiger import Tiger
from pytorch_optimizer.optimizer.trac import TRAC
from pytorch_optimizer.optimizer.utils import (
clip_grad_norm,
disable_running_stats,
@@ -131,6 +133,7 @@
HAS_Q_GALORE: bool = find_spec('q-galore-torch') is not None

OPTIMIZER_LIST: List[OPTIMIZER] = [
AdamW,
AdaBelief,
AdaBound,
PID,
@@ -350,6 +353,8 @@ def create_optimizer(

if optimizer_name == 'alig':
optimizer = optimizer(parameters, max_lr=lr, **kwargs)
elif optimizer_name in {'lomo', 'adalomo', 'adammini'}:
optimizer = optimizer(model, lr=lr, **kwargs)
else:
optimizer = optimizer(parameters, lr=lr, **kwargs)

2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/lookahead.py
@@ -22,7 +22,7 @@ def __init__(
k: int = 5,
alpha: float = 0.5,
pullback_momentum: str = 'none',
):
) -> None:
self.validate_positive(k, 'k')
self.validate_range(alpha, 'alpha', 0.0, 1.0)
self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])
253 changes: 253 additions & 0 deletions pytorch_optimizer/optimizer/trac.py
@@ -0,0 +1,253 @@
from typing import Callable, Dict, List, Tuple

import torch
from torch import nn

from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER


def polyval(x: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
r"""Implement of the Horner scheme to evaluate a polynomial.
taken from https://discuss.pytorch.org/t/polynomial-evaluation-by-horner-rule/67124
:param x: torch.Tensor. variable.
:param coef: torch.Tensor. coefficients of the polynomial.
"""
result = coef[0].clone()

for c in coef[1:]:
result = (result * x) + c

return result[0]


class ERF1994(nn.Module):
r"""Implementation of ERF1994.
:param num_coefs: int. The number of polynomial coefficients to use in the approximation.
"""

def __init__(self, num_coefs: int = 128) -> None:
super().__init__()

self.n: int = num_coefs

self.i: torch.Tensor = torch.complex(torch.tensor(0.0), torch.tensor(1.0))
self.m = 2 * self.n
self.m2 = 2 * self.m
self.k = torch.linspace(-self.m + 1, self.m - 1, self.m2 - 1)
self.l = torch.sqrt(self.n / torch.sqrt(torch.tensor(2.0)))
self.theta = self.k * torch.pi / self.m
self.t = self.l * torch.tan(self.theta / 2.0)
self.f = torch.exp(-self.t ** 2) * (self.l ** 2 + self.t ** 2) # fmt: skip
self.a = torch.fft.fft(torch.fft.fftshift(self.f)).real / self.m2
self.a = torch.flipud(self.a[1:self.n + 1]) # fmt: skip

def w_algorithm(self, z: torch.Tensor) -> torch.Tensor:
r"""Compute the Faddeeva function of a complex number.
:param z: torch.Tensor. A tensor of complex numbers.
"""
self.l = self.l.to(z.device)
self.i = self.i.to(z.device)
self.a = self.a.to(z.device)

iz = self.i * z
lp_iz, ln_iz = self.l + iz, self.l - iz

z_ = lp_iz / ln_iz
p = polyval(z_.unsqueeze(0), self.a)
return 2 * p / ln_iz.pow(2) + (1.0 / torch.sqrt(torch.tensor(torch.pi))) / ln_iz

def forward(self, z: torch.Tensor) -> torch.Tensor:
r"""Compute the error function of a complex number.
:param z: torch.Tensor. A tensor of complex numbers.
"""
sign_r = torch.sign(z.real)
sign_i = torch.sign(z.imag)
z = torch.complex(torch.abs(z.real), torch.abs(z.imag))
out = -torch.exp(torch.log(self.w_algorithm(z * self.i)) - z ** 2) + 1 # fmt: skip
return torch.complex(out.real * sign_r, out.imag * sign_i)


class TRAC(BaseOptimizer):
r"""A Parameter-Free Optimizer for Lifelong Reinforcement Learning.
Example:
-------
Here's an example::
model = YourModel()
optimizer = TRAC(AdamW(model.parameters()))
for input, output in data:
optimizer.zero_grad()
loss = loss_fn(model(input), output)
loss.backward()
optimizer.step()
:param optimizer: Optimizer. base optimizer.
:param betas: List[float]. list of beta values.
:param num_coefs: int. the number of polynomial coefficients to use in the approximation.
:param s_prev: float. initial scale value.
:param eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
optimizer: OPTIMIZER,
betas: List[float] = (0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999),
num_coefs: int = 128,
s_prev: float = 1e-8,
eps: float = 1e-8,
):
self.validate_positive(num_coefs, 'num_coefs')
self.validate_non_negative(s_prev, 's_prev')
self.validate_non_negative(eps, 'eps')

self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
self._optimizer_step_post_hooks: Dict[int, Callable] = {}

self.erf = ERF1994(num_coefs=num_coefs)
self.betas = betas
self.s_prev = s_prev
self.eps = eps

self.f_term = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))

self.optimizer = optimizer
self.defaults: DEFAULTS = optimizer.defaults

def __str__(self) -> str:
return 'TRAC'

@property
def param_groups(self):
return self.optimizer.param_groups

@property
def state(self):
return self.optimizer.state

@torch.no_grad()
def reset(self):
device = self.param_groups[0]['params'][0].device

self.state['trac'] = {
'betas': torch.tensor(self.betas, device=device),
's': torch.zeros(len(self.betas), device=device),
'variance': torch.zeros(len(self.betas), device=device),
'sigma': torch.full((len(self.betas),), 1e-8, device=device),
'step': 0,
}

for group in self.param_groups:
for p in group['params']:
self.state['trac'][p] = p.clone()

@torch.no_grad()
def zero_grad(self) -> None:
self.optimizer.zero_grad(set_to_none=True)

@torch.no_grad()
def erf_imag(self, x: torch.Tensor) -> torch.Tensor:
if not torch.is_floating_point(x):
x = x.to(torch.float32)

ix = torch.complex(torch.zeros_like(x), x)

return self.erf(ix).imag

@torch.no_grad()
def backup_params_and_grads(self) -> Tuple[Dict, Dict]:
updates, grads = {}, {}

for group in self.param_groups:
for p in group['params']:
updates[p] = p.clone()
grads[p] = p.grad.clone() if p.grad is not None else None

return updates, grads

@torch.no_grad()
def trac_step(self, updates: Dict, grads: Dict) -> None:
self.state['trac']['step'] += 1

deltas = {}

device = self.param_groups[0]['params'][0].device

h = torch.zeros((1,), device=device)
for group in self.param_groups:
for p in group['params']:
if grads[p] is None:
continue

theta_ref = self.state['trac'][p]
update = updates[p]

deltas[p] = (update - theta_ref) / torch.sum(self.state['trac']['s']).add_(self.eps)
update.neg_().add_(p)

grad, delta = grads[p], deltas[p]

product = torch.dot(delta.flatten(), grad.flatten())
h.add_(product)

delta.add_(update)

s = self.state['trac']['s']
betas = self.state['trac']['betas']
variance = self.state['trac']['variance']
sigma = self.state['trac']['sigma']

variance.mul_(betas.pow(2)).add_(h.pow(2))
sigma.mul_(betas).sub_(h)

s_term = self.erf_imag(sigma / (2.0 * variance).sqrt_().add_(self.eps))
s_term.mul_(self.f_term)
s.copy_(s_term)

scale = max(torch.sum(s), 0.0)

for group in self.param_groups:
for p in group['params']:
if grads[p] is None:
continue

delta = deltas[p]
delta.mul_(scale).add_(self.state['trac'][p])

p.copy_(delta)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
# TODO: backup is first to get the delta of param and grad, but it does not work.
with torch.enable_grad():
loss = self.optimizer.step(closure)

updates, grads = self.backup_params_and_grads()

if 'trac' not in self.state:
device = self.param_groups[0]['params'][0].device

self.state['trac'] = {
'betas': torch.tensor(self.betas, device=device),
's': torch.zeros(len(self.betas), device=device),
'variance': torch.zeros(len(self.betas), device=device),
'sigma': torch.full((len(self.betas),), 1e-8, device=device),
'step': 0,
}

for group in self.param_groups:
for p in group['params']:
self.state['trac'][p] = updates[p].clone()

self.trac_step(updates, grads)

return loss
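
As a quick sanity check of the `ERF1994` helper in `trac.py` (a sketch, assuming the class is importable from `pytorch_optimizer.optimizer.trac` and that SciPy is available for a reference value): for real `x`, the imaginary part of `erf(ix)` is the imaginary error function `erfi(x)`, which is what `TRAC.erf_imag()` feeds into the scale update.

```python
import torch
from scipy import special  # provides erfi as a reference

from pytorch_optimizer.optimizer.trac import ERF1994

erf = ERF1994(num_coefs=128)

x = torch.linspace(0.1, 2.0, 8)
ix = torch.complex(torch.zeros_like(x), x)  # purely imaginary inputs, as in TRAC.erf_imag()

approx = erf(ix).imag
reference = torch.tensor(special.erfi(x.numpy()), dtype=approx.dtype)

# The polynomial approximation should track SciPy's erfi closely on this range.
assert torch.allclose(approx, reference, atol=1e-4)
```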
1 change: 1 addition & 0 deletions tests/constants.py
@@ -85,6 +85,7 @@
'wsam',
'pcgrad',
'lookahead',
'trac',
]

SPARSE_OPTIMIZERS: List[str] = ['madgrad', 'dadaptadagrad', 'sm3']
1 change: 1 addition & 0 deletions tests/test_create_optimizer.py
@@ -9,6 +9,7 @@ def test_create_optimizer():

create_optimizer(model, 'adamp', lr=1e-2, weight_decay=1e-3, use_gc=True, use_lookahead=True)
create_optimizer(model, 'alig', lr=1e-2, use_lookahead=True)
create_optimizer(model, 'adalomo', lr=1e-2, use_lookahead=False)


def test_bnb_optimizer():
2 changes: 1 addition & 1 deletion tests/test_general_optimizer_parameters.py
@@ -8,7 +8,7 @@

@pytest.mark.parametrize('optimizer_name', VALID_OPTIMIZER_NAMES)
def test_learning_rate(optimizer_name):
if optimizer_name in ('alig', 'a2grad'):
if optimizer_name in {'alig', 'a2grad', 'adamw'}:
pytest.skip(f'skip {optimizer_name} optimizer')

optimizer = load_optimizer(optimizer_name)
12 changes: 7 additions & 5 deletions tests/test_gradients.py
@@ -1,13 +1,13 @@
import pytest
import torch

from pytorch_optimizer import SAM, WSAM, AdamP, Lookahead, load_optimizer
from pytorch_optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, load_optimizer
from pytorch_optimizer.base.exception import NoSparseGradientError
from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES
from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss


@pytest.mark.parametrize('optimizer_name', [*VALID_OPTIMIZER_NAMES, 'lookahead'])
@pytest.mark.parametrize('optimizer_name', [*VALID_OPTIMIZER_NAMES, 'lookahead', 'trac'])
def test_no_gradients(optimizer_name):
if optimizer_name in {'lomo', 'adalomo', 'adammini'}:
pytest.skip(f'skip {optimizer_name} optimizer.')
@@ -25,21 +25,23 @@ def test_no_gradients(optimizer_name):
elif optimizer_name in ('lamb', 'ralamb'):
optimizer = load_optimizer(optimizer_name)(params, pre_norm=True)
elif optimizer_name == 'lookahead':
optimizer = Lookahead(load_optimizer('adamp')(params), k=1)
optimizer = Lookahead(load_optimizer('adamw')(params), k=1)
elif optimizer_name == 'trac':
optimizer = TRAC(load_optimizer('adamw')(params))
else:
optimizer = load_optimizer(optimizer_name)(params)

optimizer.zero_grad()
sphere_loss(p1 + p3).backward(create_graph=True)

optimizer.step(lambda: 0.1) # for AliG optimizer
if optimizer_name != 'lookahead':
if optimizer_name not in {'lookahead', 'trac'}:
optimizer.zero_grad(set_to_none=True)


@pytest.mark.parametrize('no_sparse_optimizer', NO_SPARSE_OPTIMIZERS)
def test_sparse_not_supported(no_sparse_optimizer):
if no_sparse_optimizer in {'lomo', 'adalomo', 'bsam', 'adammini'}:
if no_sparse_optimizer in {'lomo', 'adalomo', 'bsam', 'adammini', 'adamw'}:
pytest.skip(f'skip {no_sparse_optimizer} optimizer.')

param = simple_sparse_parameter()[1]