Merge pull request #302 from kozistr/feature/muon-optimizer
[Feature] Implement Muon optimizer
kozistr authored Dec 3, 2024
2 parents c341872 + ecaf786 commit a980dc0
Showing 15 changed files with 331 additions and 13 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

**pytorch-optimizer** is a collection of optimizers & lr schedulers for PyTorch.
I have re-implemented the algorithms based on the original papers, with speed & memory tweaks and plug-ins. It also includes useful and practical optimization ideas.
Currently, **80 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -186,6 +186,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |
| DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |

## Supported LR Scheduler

2 changes: 2 additions & 0 deletions docs/changelogs/v3.3.1.md
@@ -4,3 +4,5 @@

* Implement `DeMo` optimizer. (#300, #301)
* [Decoupled Momentum Optimization](https://arxiv.org/abs/2411.19870)
* Implement `Muon` optimizer. (#302)
* [MomentUm Orthogonalized by Newton-schulz](https://github.com/KellerJordan/Muon)
3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

**pytorch-optimizer** is a collection of optimizers & lr schedulers for PyTorch.
I have re-implemented the algorithms based on the original papers, with speed & memory tweaks and plug-ins. It also includes useful and practical optimization ideas.
Currently, **80 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -186,6 +186,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |
| DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |

## Supported LR Scheduler

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -228,6 +228,10 @@
:docstring:
:members:

::: pytorch_optimizer.Muon
:docstring:
:members:

::: pytorch_optimizer.Nero
:docstring:
:members:
11 changes: 6 additions & 5 deletions pyproject.toml
@@ -15,11 +15,12 @@ keywords = [
"AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
"Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion",
"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS",
"Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy",
"QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP",
"Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger",
"TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
"Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
"Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
"Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
"ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM",
"StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
1 change: 1 addition & 0 deletions pytorch_optimizer/__init__.py
@@ -109,6 +109,7 @@
Lamb,
Lion,
Lookahead,
Muon,
Nero,
NovoGrad,
PAdam,
3 changes: 3 additions & 0 deletions pytorch_optimizer/optimizer/__init__.py
@@ -56,6 +56,7 @@
from pytorch_optimizer.optimizer.lookahead import Lookahead
from pytorch_optimizer.optimizer.madgrad import MADGRAD
from pytorch_optimizer.optimizer.msvag import MSVAG
from pytorch_optimizer.optimizer.muon import Muon
from pytorch_optimizer.optimizer.nero import Nero
from pytorch_optimizer.optimizer.novograd import NovoGrad
from pytorch_optimizer.optimizer.padam import PAdam
@@ -272,6 +273,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
SOAP,
ADOPT,
FTRL,
DeMo,
Muon,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

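Since `Muon` is appended to `OPTIMIZER_LIST` above, it ends up in the `OPTIMIZERS` dict under its lowercased class name and can be resolved with `load_optimizer`. A minimal sketch of that lookup, assuming `load_optimizer` is importable from the package root as in the project README (the toy model is illustrative, not part of this commit):

```python
import torch

from pytorch_optimizer import load_optimizer

# 'muon' resolves via str(Muon.__name__).lower() in the OPTIMIZERS dict.
optimizer_class = load_optimizer('muon')

model = torch.nn.Linear(64, 10)
optimizer = optimizer_class(model.parameters(), lr=2e-2)
```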
5 changes: 4 additions & 1 deletion pytorch_optimizer/optimizer/demo.py
@@ -302,11 +302,14 @@ def __init__(
process_group: Optional[ProcessGroup] = None,
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_range(compression_decay, 'compression_decay', 0.0, 1.0, range_type='[)')
self.validate_positive(compression_top_k, 'compression_top_k')
self.validate_positive(compression_chunk, 'compression_chunk')

self.weight_decay = weight_decay

self.compression_decay = compression_decay
self.compression_top_k = compression_top_k
self.compression_chunk = compression_chunk
@@ -406,7 +409,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
p,
grad,
lr=lr,
weight_decay=group['weight_decay'],
weight_decay=self.weight_decay,
weight_decouple=True,
fixed_decay=False,
)
253 changes: 253 additions & 0 deletions pytorch_optimizer/optimizer/muon.py
@@ -0,0 +1,253 @@
import os
from typing import List, Optional, Tuple

import torch
from torch.distributed import ReduceOp, all_reduce

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


def zero_power_via_newton_schulz_5(
    g: torch.Tensor, num_steps: int = 10, eps: float = 1e-7, weights: Tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
) -> torch.Tensor:
r"""Compute the zeroth power / orthogonalization of G.
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration
whose coefficients are selected to maximize the slope at zero. For the purpose of minimizing steps, it turns out
to be empirically effective to keep increasing the slope at zero even beyond the point where the iteration no
longer converges all the way to one everywhere on the interval. This iteration therefore does not produce UV^T but
rather something like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt
model performance at all relative to UV^T, where USV^T = G is the SVD.
:param g: torch.Tensor. matrix.
:param num_steps: int. number of iterations.
:param eps: float. add this times I to G, to make is positive definite. For scaling, we multiply it by the largest
eigenvalue of G.
:param weights: Tuple[int, int, int]. weights.
"""
if len(g.shape) != 2:
raise ValueError('shape of g must be 2-dimensional')

x = g.bfloat16()
x.div_(x.norm().add_(eps))

if g.size(0) > g.size(1):
x = x.T

for _ in range(num_steps):
a = x @ x.T
b = weights[1] * a + weights[2] * a @ a
x = weights[0] * x + b @ x

if g.size(0) > g.size(1):
x = x.T

return x


class Muon(BaseOptimizer):
r"""MomentUm Orthogonalized by Newton-schulz.
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- We believe this optimizer is unlikely to work well for training with small batch size.
- We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.
:param params: PARAMETERS. the parameters to be optimized by Muon.
:param lr: float. learning rate.
:param momentum: float. the momentum used by the internal SGD.
:param betas: The betas for the internal AdamW.
:param nesterov: bool. whether to use nesterov momentum.
:param ns_steps: int. the number of Newton-Schulz iterations to run. (6 is probably always enough)
:param adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or
are detected as being the embed or lm_head will be optimized by AdamW as well.
:param adamw_lr: The learning rate for the internal AdamW.
:param adamw_wd: The weight decay for the internal AdamW.
:param adamw_eps: The epsilon for the internal AdamW.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 2e-2,
momentum: float = 0.95,
betas: BETAS = (0.95, 0.95),
nesterov: bool = True,
ns_steps: int = 6,
adamw_params: Optional[PARAMETERS] = None,
adamw_lr: float = 3e-4,
adamw_wd: float = 0,
adamw_eps: float = 1e-8,
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_learning_rate(adamw_lr)
self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
self.validate_positive(ns_steps, 'ns_steps')
self.validate_betas(betas)
self.validate_non_negative(adamw_wd, 'adamw_wd')
self.validate_non_negative(adamw_eps, 'adamw_eps')

params = self.get_parameters(params)
adamw_params = self.get_parameters(adamw_params) if adamw_params is not None else []
params.extend(adamw_params)

self.world_size: int = int(os.environ.get('WORLD_SIZE', 1))
self.rank: int = int(os.environ.get('RANK', 0))

defaults: DEFAULTS = {
'lr': lr,
'momentum': momentum,
'nesterov': nesterov,
'ns_steps': ns_steps,
'adamw_lr': adamw_lr,
'adamw_lr_ratio': adamw_lr / lr,
'adamw_betas': betas,
'adamw_wd': adamw_wd,
'adamw_eps': adamw_eps,
}
super().__init__(params, defaults)

self.set_muon_state(params, adamw_params)

def __str__(self) -> str:
return 'Muon'

@staticmethod
def get_parameters(params: PARAMETERS) -> List[torch.Tensor]:
        if isinstance(params, list) and len(params) > 0 and isinstance(params[0], torch.Tensor):
            return params

new_params = []
for group in params:
if isinstance(group, dict) and 'params' in group:
new_params.extend(list(group['params']))
else:
new_params.append(group)

return new_params

def set_muon_state(self, params: PARAMETERS, adamw_params: PARAMETERS, threshold: int = 8192) -> None:
r"""Set use_muon flag."""
for p in params:
self.state[p]['use_muon'] = p.ndim >= 2 and p.size(0) < threshold

for p in adamw_params:
self.state[p]['use_muon'] = False

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

state['momentum_buffer'] = torch.zeros_like(p)
state['moment1'] = torch.zeros_like(p)
state['moment2'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

params = []
for p in group['params']:
if p.grad is not None and self.state[p]['use_muon']:
if p.grad.is_sparse:
raise NoSparseGradientError(str(self))
params.append(p)

if len(params) == 0:
continue

lr = group['lr']
momentum = group['momentum']

total_params: int = sum(p.numel() for p in params)
updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
curr_idx: int = 0

for i, p in enumerate(params):
if i % self.world_size != self.rank:
curr_idx += p.numel()
continue

g = p.grad
if g.ndim > 2:
g = g.view(g.size(0), -1)

state = self.state[p]
if 'momentum_buffer' not in state:
state['momentum_buffer'] = torch.zeros_like(g)

buf = state['momentum_buffer']
buf.mul_(momentum).add_(g)

if group['nesterov']:
g.add_(buf, alpha=momentum)
else:
g = buf

g = zero_power_via_newton_schulz_5(g, num_steps=group['ns_steps'])
g.mul_(max(1.0, g.size(0) / g.size(1)) ** 0.5)

updates_flat[curr_idx:curr_idx + p.numel()] = g.flatten() # fmt: skip

if self.world_size > 1: # pragma: no cover
all_reduce(updates_flat, op=ReduceOp.SUM)

curr_idx: int = 0
for p in params:
g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p).type_as(p) # fmt: skip
p.add_(g, alpha=-lr)
curr_idx += p.numel()

params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']]

lr: float = group['adamw_lr_ratio'] * group['lr']
beta1, beta2 = group['adamw_betas']

bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2: float = self.debias(beta2, group['step'])
scale: float = bias_correction1 / bias_correction2 ** 0.5 # fmt: skip
step_size: float = lr / scale

for p in params:
grad = p.grad
state = self.state[p]
if 'moment1' not in state:
state['moment1'] = torch.zeros_like(grad)
state['moment2'] = torch.zeros_like(grad)

buf1, buf2 = state['moment1'], state['moment2']
buf1.lerp_(grad, weight=1.0 - beta1)
buf2.lerp_(grad.square(), weight=1.0 - beta2)

update = buf1 / buf2.sqrt().add_(group['adamw_eps'])

self.apply_weight_decay(
p,
grad,
lr=lr,
weight_decay=group['adamw_wd'],
weight_decouple=True,
fixed_decay=False,
)

p.add_(update, alpha=-step_size)

return loss
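As a rough usage sketch of the class above (the toy model, data, and hyper-parameters are illustrative assumptions, not part of this commit): 2-D weights passed as `params` take the momentum + Newton-Schulz path, while anything passed as `adamw_params`, along with 0/1-D parameters, is flagged `use_muon=False` by `set_muon_state` and handled by the internal AdamW branch of `step`.

```python
import torch
import torch.nn.functional as F

from pytorch_optimizer import Muon

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))

# 2-D weight matrices get the orthogonalized momentum update; 1-D biases are
# routed to the internal AdamW by set_muon_state.
muon_params = [p for p in model.parameters() if p.ndim >= 2]
adamw_params = [p for p in model.parameters() if p.ndim < 2]

optimizer = Muon(muon_params, lr=2e-2, momentum=0.95, adamw_params=adamw_params, adamw_lr=3e-4, adamw_wd=1e-2)

x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))
loss = F.cross_entropy(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The `zero_power_via_newton_schulz_5` helper can also be checked in isolation. As its docstring notes, the quintic iteration only pushes singular values roughly into [0.5, 1.5] rather than exactly to 1; a small sanity check under that assumption:

```python
import torch

from pytorch_optimizer.optimizer.muon import zero_power_via_newton_schulz_5

x = zero_power_via_newton_schulz_5(torch.randn(64, 128), num_steps=6).float()

# Singular values of the orthogonalized matrix should sit roughly in [0.5, 1.5].
s = torch.linalg.svdvals(x)
print(s.min().item(), s.max().item())
```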
4 changes: 4 additions & 0 deletions tests/constants.py
@@ -58,6 +58,7 @@
Kate,
Lamb,
Lion,
Muon,
Nero,
NovoGrad,
PAdam,
@@ -144,6 +145,7 @@
'adamg',
'ademamix',
'soap',
'muon',
]

VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -495,6 +497,8 @@
),
(ADOPT, {'lr': 1e0}, 5),
(FTRL, {'lr': 1e0, 'beta': 0.0, 'lambda_1': 0.0, 'lambda_2': 0.0}, 5),
    (Muon, {'lr': 1e0, 'ns_steps': 6, 'adamw_lr': 1e0, 'adamw_wd': 1e-2}, 5),
    (Muon, {'lr': 1e0, 'ns_steps': 6, 'adamw_lr': 1e0, 'adamw_wd': 1e-2, 'nesterov': False}, 5),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),