Add Prodigy Plus Schedule Free optimizer #614

Open · wants to merge 12 commits into base: master
19 changes: 16 additions & 3 deletions modules/ui/OptimizerParamsWindow.py
@@ -101,7 +101,6 @@ def create_dynamic_ui(
'd_coef': {'title': 'D Coefficient', 'tooltip': 'Coefficient in the expression for the estimate of d.', 'type': 'float'},
'dampening': {'title': 'Dampening', 'tooltip': 'Dampening for optimizer_momentum.', 'type': 'float'},
'decay_rate': {'title': 'Decay Rate', 'tooltip': 'Rate of decay for moment estimation.', 'type': 'float'},
'decouple': {'title': 'Decouple', 'tooltip': 'Use AdamW style optimizer_decoupled weight decay.', 'type': 'bool'},
'differentiable': {'title': 'Differentiable', 'tooltip': 'Whether the optimization function is optimizer_differentiable.', 'type': 'bool'},
'eps': {'title': 'EPS', 'tooltip': 'A small value to prevent division by zero.', 'type': 'float'},
'eps2': {'title': 'EPS 2', 'tooltip': 'A small value to prevent division by zero.', 'type': 'float'},
@@ -142,9 +141,18 @@ def create_dynamic_ui(
'r': {'title': 'R', 'tooltip': 'EMA factor.', 'type': 'float'},
'adanorm': {'title': 'AdaNorm', 'tooltip': 'Whether to use the AdaNorm variant', 'type': 'bool'},
'adam_debias': {'title': 'Adam Debias', 'tooltip': 'Only correct the denominator to avoid inflating step sizes early in training.', 'type': 'bool'},
'cautious': {'title': 'Cautious', 'tooltip': 'Whether to use the Cautious variant.', 'type': 'bool'},

'split_groups': {'title': 'Split Groups', 'tooltip': 'Track individual adaptation values for each parameter group. Recommended: True', 'type': 'bool'},
'split_groups_mean': {'title': 'Split Groups Mean', 'tooltip': 'When split_groups is True, use the harmonic mean of learning rates for all groups. This favours a more conservative LR', 'type': 'bool'},
'factored': {'title': 'Factored', 'tooltip': 'Use factored approximation of the second moment, similar to Adafactor. Recommended: True', 'type': 'bool'},
'use_stableadamw': {'title': 'Use StableAdamW', 'tooltip': 'Scales parameter updates by the root-mean-square of the normalised gradient, in essence identical to the gradient scaling used by Adafactor. Recommended: True', 'type': 'bool'},
'use_muon_pp': {'title': 'Use Muon++', 'tooltip': 'Whether to use Muon++ variant.', 'type': 'bool'},
'use_cautious': {'title': 'Use Cautious', 'tooltip': 'Experimental. Perform "cautious" updates, as proposed in https://arxiv.org/pdf/2411.16085. Recommended: False', 'type': 'bool'},
'use_adopt': {'title': 'Use ADOPT', 'tooltip': 'Experimental. Partial implementation of (https://arxiv.org/abs/2411.02853). Recommended: False', 'type': 'bool'},
'lr': {'title': 'Learning Rate', 'tooltip': 'Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate. Recommended: 1.0', 'type': 'float'},
'weignt_decay_by_lr': {'title': 'Weight Decay by LR', 'tooltip': 'If True, weight_decay is multiplied by the adaptive learning rate. Recommended: True', 'type': 'bool'},
Contributor:
weignt_decay_by_lr -> weight_decay_by_lr

Contributor:
Actually this is in a couple of places.

'prodigy_steps': {'title': 'Prodigy Steps', 'tooltip': 'Freeze Prodigy stepsize adjustments after a certain optimiser step and release all state memory it required. Recommended: 25% of total steps', 'type': 'int'},
}

# @formatter:on

if not self.winfo_exists(): # check if this window isn't open
@@ -154,6 +162,11 @@ def create_dynamic_ui(

# Extract the keys for the selected optimizer
for index, key in enumerate(OPTIMIZER_DEFAULT_PARAMETERS[selected_optimizer].keys()):
if selected_optimizer == Optimizer.PRODIGY_PLUS_SCHEDULE_FREE and key not in [
Owner:
Can you explain this change? I'm not quite sure if or why it's needed

Owner:
@saunderez Can you comment on this? I don't want to merge something if I don't understand why it's done

Contributor:
I think this is a side effect of setting the lr to 1.0 in the OPTIMIZER_DEFAULT_PARAMETERS and not wanting to present it as configurable.

            if selected_optimizer == Optimizer.PRODIGY_PLUS_SCHEDULE_FREE and key in [
                'lr'
            ]:
                continue

is really all it's doing here.

Given that the other learning-rate-free optimizers don't do this, it's probably better if lr is removed from the Optimizer.PRODIGY_PLUS_SCHEDULE_FREE defaults; then the conditional isn't necessary.

Out of scope for this change would be fixing this sharp edge for all the optimizers that expect an lr of 1.0.
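
An illustrative, self-contained sketch of the equivalence being suggested (the names below are hypothetical placeholders, not the actual module code): filtering 'lr' out in the UI loop and removing 'lr' from the defaults dict produce the same set of visible keys.

    # Hypothetical sketch, not the actual OneTrainer code: shows why dropping 'lr'
    # from the defaults makes the special-case 'continue' in the UI loop redundant.
    defaults_with_lr = {"lr": 1.0, "d_coef": 1.0, "prodigy_steps": 0}
    defaults_without_lr = {"d_coef": 1.0, "prodigy_steps": 0}

    def visible_keys(defaults, hidden=("lr",)):
        # Mirrors what the "key in ['lr']: continue" check achieves in create_dynamic_ui.
        return [k for k in defaults if k not in hidden]

    assert visible_keys(defaults_with_lr) == visible_keys(defaults_without_lr) == list(defaults_without_lr)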

'beta1', 'beta2', 'eps', 'weight_decay', 'use_bias_correction', 'safeguard_warmup', 'd0', 'd_coef', 'growth_rate', 'fsdp_in_use', 'split_groups', 'split_groups_mean', 'factored', 'fused_back_pass', 'use_stableadamw', 'use_muon_pp', 'use_cautious', 'use_adopt', 'weignt_decay_by_lr', 'prodigy_steps'
]:
continue

arg_info = KEY_DETAIL_MAP[key]

title = arg_info['title']
26 changes: 26 additions & 0 deletions modules/util/config/TrainConfig.py
@@ -90,6 +90,24 @@ class TrainOptimizerConfig(BaseConfig):
adanorm: bool
adam_debias: bool
cautious: bool
split_groups: bool
split_groups_mean: bool
factored: bool
use_stableadamw: bool
use_muon_pp: bool
use_cautious: bool
use_adopt: bool
prodigy_steps: int
use_adopt: bool
use_cautious: bool
Contributor:
Is this supposed to be here twice?

use_muon_pp: bool
use_stableadamw: bool
weight_decay_by_lr: bool
factored: bool
split_groups: bool
split_groups_mean: bool
fused_back_pass: bool


def __init__(self, data: list[(str, Any, type, bool)]):
super().__init__(data)
@@ -158,6 +176,14 @@ def default_values():
data.append(("adanorm", False, bool, False))
data.append(("adam_debias", False, bool, False))
data.append(("cautious", False, bool, False))
data.append(("split_groups", True, bool, False))
data.append(("split_groups_mean", True, bool, False))
data.append(("factored", True, bool, False))
data.append(("use_stableadamw", True, bool, False))
data.append(("use_muon_pp", False, bool, False))
data.append(("use_cautious", False, bool, False))
data.append(("use_adopt", False, bool, False))
data.append(("prodigy_steps", 0, int, False))
Contributor:
missing weight_decay_by_lr
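
A one-line sketch of what the missing default could look like, assuming it should mirror the weight_decay_by_lr=True fallback used in create.py (hypothetical, not part of the PR as submitted):

    # Hypothetical: sits alongside the other data.append(...) calls in default_values();
    # True mirrors the weight_decay_by_lr fallback used in create.py.
    data = []  # stand-in for the list built earlier in default_values()
    data.append(("weight_decay_by_lr", True, bool, False))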


return TrainOptimizerConfig(data)

34 changes: 28 additions & 6 deletions modules/util/create.py
@@ -529,7 +529,7 @@ def create_optimizer(
lr=config.learning_rate,
betas=(optimizer_config.beta1 if optimizer_config.beta1 is not None else 0.9,
optimizer_config.beta2 if optimizer_config.beta2 is not None else 0.999,
optimizer_config.beta3 if optimizer_config.beta1 is not None else 0.9999,),
optimizer_config.beta3 if optimizer_config.beta3 is not None else 0.9999),
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 1e-2,
eps=optimizer_config.eps if optimizer_config.eps is not None else 1e-8,
alpha=optimizer_config.alpha if optimizer_config.alpha is not None else 5,
@@ -759,7 +759,6 @@ def create_optimizer(
eps=optimizer_config.eps if optimizer_config.eps is not None else 1e-8,
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0,
log_every=optimizer_config.log_every if optimizer_config.log_every is not None else 0,
decouple=optimizer_config.decouple if optimizer_config.decouple is not None else False,
use_bias_correction=optimizer_config.use_bias_correction if optimizer_config.use_bias_correction is not None else False,
d0=optimizer_config.d0 if optimizer_config.d0 is not None else 1e-6,
growth_rate=optimizer_config.growth_rate if optimizer_config.growth_rate is not None else float('inf'),
@@ -818,11 +817,8 @@ def create_optimizer(
params=parameters,
lr=config.learning_rate,
betas=(optimizer_config.beta1 if optimizer_config.beta1 is not None else 0.9,
optimizer_config.beta2 if optimizer_config.beta2 is not None else 0.999),
beta3=optimizer_config.beta3 if optimizer_config.beta3 is not None else None,
optimizer_config.beta2 if optimizer_config.beta2 is not None else 0.999),
eps=optimizer_config.eps if optimizer_config.eps is not None else 1e-8,
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0,
decouple=optimizer_config.decouple if optimizer_config.decouple is not None else True,
use_bias_correction=optimizer_config.use_bias_correction if optimizer_config.use_bias_correction is not None else False,
safeguard_warmup=optimizer_config.safeguard_warmup if optimizer_config.safeguard_warmup is not None else False,
d0=optimizer_config.d0 if optimizer_config.d0 is not None else 1e-6,
@@ -831,6 +827,32 @@
fsdp_in_use=optimizer_config.fsdp_in_use if optimizer_config.fsdp_in_use is not None else False,
)

# PRODIGY_PLUS_SCHEDULE_FREE Optimizer
case Optimizer.PRODIGY_PLUS_SCHEDULE_FREE:
from prodigyplus import ProdigyPlusScheduleFree
optimizer = ProdigyPlusScheduleFree(
params=parameters,
lr=config.learning_rate,
betas=(optimizer_config.beta1 if optimizer_config.beta1 is not None else 0.9,
optimizer_config.beta2 if optimizer_config.beta2 is not None else 0.999),
eps=optimizer_config.eps if optimizer_config.eps is not None else 1e-8,
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0,
use_bias_correction=optimizer_config.use_bias_correction if optimizer_config.use_bias_correction is not None else False,
d0=optimizer_config.d0 if optimizer_config.d0 is not None else 1e-6,
d_coef=optimizer_config.d_coef if optimizer_config.d_coef is not None else 1.0,
prodigy_steps=optimizer_config.prodigy_steps if optimizer_config.prodigy_steps is not None else 0,
split_groups=optimizer_config.split_groups if optimizer_config.split_groups is not None else True,
split_groups_mean=optimizer_config.split_groups_mean if optimizer_config.split_groups_mean is not None else True,
factored=optimizer_config.factored if optimizer_config.factored is not None else True,
fused_back_pass=optimizer_config.fused_back_pass if optimizer_config.fused_back_pass is not None else False,
use_stableadamw=optimizer_config.use_stableadamw if optimizer_config.use_stableadamw is not None else True,
use_muon_pp=optimizer_config.use_muon_pp if optimizer_config.use_muon_pp is not None else False,
use_cautious=optimizer_config.use_cautious if optimizer_config.use_cautious is not None else False,
use_adopt=optimizer_config.use_adopt if optimizer_config.use_adopt is not None else False,
stochastic_rounding=optimizer_config.stochastic_rounding if optimizer_config.stochastic_rounding is not None else True,
weight_decay_by_lr=optimizer_config.weight_decay_by_lr if optimizer_config.weight_decay_by_lr is not None else True,
)

# ADAFactor Optimizer
case Optimizer.ADAFACTOR:
from transformers.optimization import Adafactor
4 changes: 4 additions & 0 deletions modules/util/enum/Optimizer.py
@@ -52,6 +52,7 @@ class Optimizer(Enum):

# Prodigy
PRODIGY = 'PRODIGY'
PRODIGY_PLUS_SCHEDULE_FREE = 'PRODIGY_PLUS_SCHEDULE_FREE'

# ADAFACTOR
ADAFACTOR = 'ADAFACTOR'
@@ -73,13 +74,15 @@ def is_adaptive(self):
self.DADAPT_ADA_GRAD,
self.DADAPT_LION,
self.PRODIGY,
self.PRODIGY_PLUS_SCHEDULE_FREE,
]

@property
def is_schedule_free(self):
return self in [
self.SCHEDULE_FREE_ADAMW,
self.SCHEDULE_FREE_SGD,
self.PRODIGY_PLUS_SCHEDULE_FREE,
]

def supports_fused_back_pass(self):
@@ -88,6 +91,7 @@ def supports_fused_back_pass(self):
Optimizer.CAME,
Optimizer.ADAM,
Optimizer.ADAMW,
Optimizer.PRODIGY_PLUS_SCHEDULE_FREE,
]

# Small helper for adjusting learning rates to adaptive optimizers.
28 changes: 25 additions & 3 deletions modules/util/optimizer_util.py
@@ -42,9 +42,9 @@ def update_optimizer_config(train_config: TrainConfig):
saved_optimizer_config = train_config.optimizer_defaults[str(optimizer)]
saved_optimizer_config.from_dict(train_config.optimizer.to_dict())
else:
optimizer_donfig = TrainOptimizerConfig.default_values()
optimizer_donfig.from_dict(train_config.optimizer.to_dict())
train_config.optimizer_defaults[str(optimizer)] = optimizer_donfig
optimizer_config = TrainOptimizerConfig.default_values()
optimizer_config.from_dict(train_config.optimizer.to_dict())
train_config.optimizer_defaults[str(optimizer)] = optimizer_config


def init_model_parameters(
@@ -270,6 +270,28 @@ def init_model_parameters(
"growth_rate": float('inf'),
"fsdp_in_use": False,
},
Optimizer.PRODIGY_PLUS_SCHEDULE_FREE: {
"lr": 1.0,
"beta1": 0.9,
"beta2": 0.999,
"beta3": None,
"weight_decay": 0.0,
"weignt_decay_by_lr": True,
"use_bias_correction": False,
"d0": 1e-6,
"d_coef": 1.0,
"prodigy_steps": 0,
"eps": 1e-8,
"split_groups": True,
"split_groups_mean": True,
"factored": True,
"fused_back_pass": False,
"use_stableadamw": True,
"use_muon_pp": False,
"use_cautious": False,
"use_adopt": False,
"stochastic_rounding": True,
},
Optimizer.DADAPT_ADA_GRAD: {
"momentum": 0,
"log_every": 0,
1 change: 1 addition & 0 deletions requirements-global.txt
@@ -34,6 +34,7 @@ lion-pytorch==0.2.2 # lion optimizer
prodigyopt==1.0 # prodigy optimizer
schedulefree==1.3.0 # schedule-free optimizers
pytorch_optimizer==3.3.0 # pytorch optimizers
prodigy-plus-schedule-free==1.8.0
Contributor:
1.9.0 was released rather recently. The only interface change is an extra parameter, factored_fp32, with a default of True.
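
A small, clearly hypothetical illustration of why such a bump would be backwards compatible (the function below is a stand-in, not the library API): a new keyword argument with a default value leaves existing call sites untouched.

    # Stand-in function, not ProdigyPlusScheduleFree itself: models a 1.9.0-style
    # signature where factored_fp32 is new and defaults to True.
    def make_optimizer(params, factored=True, factored_fp32=True):
        return {"params": params, "factored": factored, "factored_fp32": factored_fp32}

    # An existing 1.8.0-style call keeps working; factored_fp32 simply takes its default.
    assert make_optimizer([1, 2, 3])["factored_fp32"] is True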


# Profiling
scalene==1.5.45