Add Prodigy Plus Schedule Free optimizer #614

Open
wants to merge 12 commits into base: master
16 changes: 13 additions & 3 deletions modules/ui/OptimizerParamsWindow.py
@@ -101,7 +101,6 @@ def create_dynamic_ui(
'd_coef': {'title': 'D Coefficient', 'tooltip': 'Coefficient in the expression for the estimate of d.', 'type': 'float'},
'dampening': {'title': 'Dampening', 'tooltip': 'Dampening for optimizer_momentum.', 'type': 'float'},
'decay_rate': {'title': 'Decay Rate', 'tooltip': 'Rate of decay for moment estimation.', 'type': 'float'},
'decouple': {'title': 'Decouple', 'tooltip': 'Use AdamW style optimizer_decoupled weight decay.', 'type': 'bool'},
'differentiable': {'title': 'Differentiable', 'tooltip': 'Whether the optimization function is optimizer_differentiable.', 'type': 'bool'},
'eps': {'title': 'EPS', 'tooltip': 'A small value to prevent division by zero.', 'type': 'float'},
'eps2': {'title': 'EPS 2', 'tooltip': 'A small value to prevent division by zero.', 'type': 'float'},
@@ -142,8 +141,14 @@ def create_dynamic_ui(
'r': {'title': 'R', 'tooltip': 'EMA factor.', 'type': 'float'},
'adanorm': {'title': 'AdaNorm', 'tooltip': 'Whether to use the AdaNorm variant', 'type': 'bool'},
'adam_debias': {'title': 'Adam Debias', 'tooltip': 'Only correct the denominator to avoid inflating step sizes early in training.', 'type': 'bool'},
'cautious': {'title': 'Cautious', 'tooltip': 'Whether to use the Cautious variant.', 'type': 'bool'},
'split_groups': {'title': 'Split Groups', 'tooltip': 'Whether to split parameter groups.', 'type': 'bool'},
'split_groups_mean': {'title': 'Split Groups Mean', 'tooltip': 'Whether to use mean for split groups.', 'type': 'bool'},
'factored': {'title': 'Factored', 'tooltip': 'Whether to use factored updates.', 'type': 'bool'},
'use_stableadamw': {'title': 'Use StableAdamW', 'tooltip': 'Whether to use StableAdamW variant.', 'type': 'bool'},
'use_muon_pp': {'title': 'Use Muon++', 'tooltip': 'Whether to use Muon++ variant.', 'type': 'bool'},
'use_cautious': {'title': 'Use Cautious', 'tooltip': 'Whether to use Cautious variant.', 'type': 'bool'},
'use_adopt': {'title': 'Use ADOPT', 'tooltip': 'Whether to use ADOPT variant.', 'type': 'bool'},
}
# @formatter:on

@@ -154,6 +159,11 @@ def create_dynamic_ui(

# Extract the keys for the selected optimizer
for index, key in enumerate(OPTIMIZER_DEFAULT_PARAMETERS[selected_optimizer].keys()):
if selected_optimizer == Optimizer.PRODIGY_PLUS_SCHEDULE_FREE and key not in [
Owner:

Can you explain this change? I'm not quite sure if or why it's needed

Owner:

@saunderez Can you comment on this? I don't want to merge something if I don't understand why it's done

Contributor:

I think this is a side effect of setting the lr to 1.0 in OPTIMIZER_DEFAULT_PARAMETERS and not wanting to present it as configurable.

            if selected_optimizer == Optimizer.PRODIGY_PLUS_SCHEDULE_FREE and key in [
                'lr'
            ]:
                continue

is really all it's doing here.

Given that the other learning-rate-free optimizers don't do this, it's probably better if lr is removed from the Optimizer.PRODIGY_PLUS_SCHEDULE_FREE defaults, and then the conditional isn't necessary (a sketch of that simplification follows the diff lines below).

Out of scope for this change would be fixing this sharp edge for all the optimizers that expect an lr of 1.0.

'beta1', 'beta2', 'eps', 'weight_decay', 'use_bias_correction', 'safeguard_warmup', 'd0', 'd_coef', 'growth_rate', 'fsdp_in_use', 'split_groups', 'split_groups_mean', 'factored', 'fused_back_pass', 'use_stableadamw', 'use_muon_pp', 'use_cautious', 'use_adopt'
]:
continue

arg_info = KEY_DETAIL_MAP[key]

title = arg_info['title']
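
A minimal, self-contained sketch of the simplification suggested above (the dict and function are local stand-ins, not the real modules): once 'lr' is simply absent from this optimizer's defaults, the generic key loop needs no per-optimizer filter at all.

    # Stand-in for OPTIMIZER_DEFAULT_PARAMETERS, with 'lr' deliberately omitted for the
    # learning-rate-free optimizer.
    OPTIMIZER_DEFAULTS = {
        "PRODIGY_PLUS_SCHEDULE_FREE": {"beta1": 0.9, "beta2": 0.999, "d0": 1e-6},
    }

    def visible_keys(selected_optimizer: str) -> list[str]:
        # No `if key == 'lr': continue` special case is needed: the key never appears.
        return list(OPTIMIZER_DEFAULTS[selected_optimizer].keys())

    print(visible_keys("PRODIGY_PLUS_SCHEDULE_FREE"))  # ['beta1', 'beta2', 'd0']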
16 changes: 16 additions & 0 deletions modules/util/config/TrainConfig.py
@@ -90,6 +90,14 @@ class TrainOptimizerConfig(BaseConfig):
adanorm: bool
adam_debias: bool
cautious: bool
split_groups: bool
split_groups_mean: bool
factored: bool
use_stableadamw: bool
use_muon_pp: bool
use_cautious: bool
use_adopt: bool
prodigy_steps: int

def __init__(self, data: list[(str, Any, type, bool)]):
super().__init__(data)
@@ -158,6 +166,14 @@ def default_values():
data.append(("adanorm", False, bool, False))
data.append(("adam_debias", False, bool, False))
data.append(("cautious", False, bool, False))
data.append(("split_groups", True, bool, False))
data.append(("split_groups_mean", True, bool, False))
data.append(("factored", True, bool, False))
data.append(("use_stableadamw", True, bool, False))
data.append(("use_muon_pp", False, bool, False))
data.append(("use_cautious", False, bool, False))
data.append(("use_adopt", False, bool, False))
data.append(("prodigy_steps", 0, int, False))
Contributor:

missing weight_decay_by_lr
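
A sketch of the entry being flagged, following the pattern of the surrounding data.append calls; the default of True is an assumption based on the prodigy-plus-schedule-free defaults and should be verified before merging.

    data.append(("weight_decay_by_lr", True, bool, False))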


return TrainOptimizerConfig(data)

36 changes: 31 additions & 5 deletions modules/util/create.py
@@ -529,7 +529,7 @@ def create_optimizer(
lr=config.learning_rate,
betas=(optimizer_config.beta1 if optimizer_config.beta1 is not None else 0.9,
optimizer_config.beta2 if optimizer_config.beta2 is not None else 0.999,
optimizer_config.beta3 if optimizer_config.beta3 is not None else 0.9999),
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 1e-2,
eps=optimizer_config.eps if optimizer_config.eps is not None else 1e-8,
alpha=optimizer_config.alpha if optimizer_config.alpha is not None else 5,
@@ -759,7 +759,6 @@ def create_optimizer(
eps=optimizer_config.eps if optimizer_config.eps is not None else 1e-8,
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0,
log_every=optimizer_config.log_every if optimizer_config.log_every is not None else 0,
decouple=optimizer_config.decouple if optimizer_config.decouple is not None else False,
use_bias_correction=optimizer_config.use_bias_correction if optimizer_config.use_bias_correction is not None else False,
d0=optimizer_config.d0 if optimizer_config.d0 is not None else 1e-6,
growth_rate=optimizer_config.growth_rate if optimizer_config.growth_rate is not None else float('inf'),
@@ -818,17 +817,44 @@ def create_optimizer(
params=parameters,
lr=config.learning_rate,
betas=(optimizer_config.beta1 if optimizer_config.beta1 is not None else 0.9,
optimizer_config.beta2 if optimizer_config.beta2 is not None else 0.999),
beta3=optimizer_config.beta3 if optimizer_config.beta3 is not None else None,
eps=optimizer_config.eps if optimizer_config.eps is not None else 1e-8,
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0,
use_bias_correction=optimizer_config.use_bias_correction if optimizer_config.use_bias_correction is not None else False,
safeguard_warmup=optimizer_config.safeguard_warmup if optimizer_config.safeguard_warmup is not None else False,
d0=optimizer_config.d0 if optimizer_config.d0 is not None else 1e-6,
d_coef=optimizer_config.d_coef if optimizer_config.d_coef is not None else 1.0,
growth_rate=optimizer_config.growth_rate if optimizer_config.growth_rate is not None else float('inf'),
fsdp_in_use=optimizer_config.fsdp_in_use if optimizer_config.fsdp_in_use is not None else False,
)

# PRODIGY_PLUS_SCHEDULE_FREE Optimizer
case Optimizer.PRODIGY_PLUS_SCHEDULE_FREE:
from prodigyplus.prodigy_plus_schedulefree import ProdigyPlusScheduleFree
optimizer = ProdigyPlusScheduleFree(
params=parameters,
lr=config.learning_rate,
betas=(optimizer_config.beta1 if optimizer_config.beta1 is not None else 0.9,
optimizer_config.beta2 if optimizer_config.beta2 is not None else 0.999),
beta3=optimizer_config.beta3 if optimizer_config.beta3 is not None else None,
eps=optimizer_config.eps if optimizer_config.eps is not None else 1e-8,
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0,
decouple=optimizer_config.decouple if optimizer_config.decouple is not None else True,
use_bias_correction=optimizer_config.use_bias_correction if optimizer_config.use_bias_correction is not None else False,
safeguard_warmup=optimizer_config.safeguard_warmup if optimizer_config.safeguard_warmup is not None else False,
d0=optimizer_config.d0 if optimizer_config.d0 is not None else 1e-6,
d_coef=optimizer_config.d_coef if optimizer_config.d_coef is not None else 1.0,
prodigy_steps=optimizer_config.prodigy_steps if optimizer_config.prodigy_steps is not None else 0,
growth_rate=optimizer_config.growth_rate if optimizer_config.growth_rate is not None else float('inf'),
fsdp_in_use=optimizer_config.fsdp_in_use if optimizer_config.fsdp_in_use is not None else False,
split_groups=optimizer_config.split_groups if optimizer_config.split_groups is not None else True,
split_groups_mean=optimizer_config.split_groups_mean if optimizer_config.split_groups_mean is not None else True,
factored=optimizer_config.factored if optimizer_config.factored is not None else True,
fused_back_pass=optimizer_config.fused_back_pass if optimizer_config.fused_back_pass is not None else False,
use_stableadamw=optimizer_config.use_stableadamw if optimizer_config.use_stableadamw is not None else True,
use_muon_pp=optimizer_config.use_muon_pp if optimizer_config.use_muon_pp is not None else False,
use_cautious=optimizer_config.use_cautious if optimizer_config.use_cautious is not None else False,
use_adopt=optimizer_config.use_adopt if optimizer_config.use_adopt is not None else False,
)

# ADAFactor Optimizer
4 changes: 4 additions & 0 deletions modules/util/enum/Optimizer.py
@@ -52,6 +52,7 @@ class Optimizer(Enum):

# Prodigy
PRODIGY = 'PRODIGY'
PRODIGY_PLUS_SCHEDULE_FREE = 'PRODIGY_PLUS_SCHEDULE_FREE'

# ADAFACTOR
ADAFACTOR = 'ADAFACTOR'
@@ -73,13 +74,15 @@ def is_adaptive(self):
self.DADAPT_ADA_GRAD,
self.DADAPT_LION,
self.PRODIGY,
self.PRODIGY_PLUS_SCHEDULE_FREE,
]

@property
def is_schedule_free(self):
return self in [
self.SCHEDULE_FREE_ADAMW,
self.SCHEDULE_FREE_SGD,
self.PRODIGY_PLUS_SCHEDULE_FREE,
]

def supports_fused_back_pass(self):
@@ -88,6 +91,7 @@ def supports_fused_back_pass(self):
Optimizer.CAME,
Optimizer.ADAM,
Optimizer.ADAMW,
Optimizer.PRODIGY_PLUS_SCHEDULE_FREE,
]

# Small helper for adjusting learning rates to adaptive optimizers.
21 changes: 21 additions & 0 deletions modules/util/optimizer_util.py
@@ -270,6 +270,27 @@ def init_model_parameters(
"growth_rate": float('inf'),
"fsdp_in_use": False,
},
Optimizer.PRODIGY_PLUS_SCHEDULE_FREE: {
"beta1": 0.9,
"beta2": 0.999,
"beta3": None,
"eps": 1e-8,
"weight_decay": 0,
"use_bias_correction": False,
"safeguard_warmup": False,
"d0": 1e-6,
"d_coef": 1.0,
"prodigy_steps": 0,
"growth_rate": float('inf'),
"split_groups": True,
"split_groups_mean": True,
"factored": True,
"fused_back_pass": False,
"use_stableadamw": True,
"use_muon_pp": False,
"use_cautious": False,
"use_adopt": False,
},
Optimizer.DADAPT_ADA_GRAD: {
"momentum": 0,
"log_every": 0,
1 change: 1 addition & 0 deletions requirements-global.txt
@@ -34,6 +34,7 @@ lion-pytorch==0.2.2 # lion optimizer
prodigyopt==1.0 # prodigy optimizer
schedulefree==1.3.0 # schedule-free optimizers
pytorch_optimizer==3.3.0 # pytorch optimizers
prodigy-plus-schedule-free==1.8.0
Contributor:

1.9.0 got released rather recently. The only interface change is an extra parameter, factored_fp32, with a default of True.
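
A minimal sketch of how the extra 1.9.0 flag could be forwarded from create.py if the pin is bumped, assuming a matching factored_fp32 field is also added to TrainOptimizerConfig and the optimizer_util defaults (that field is an assumption, not part of this diff).

    # Hypothetical helper; `factored_fp32` does not exist on the config object yet.
    def factored_fp32_kwargs(optimizer_config) -> dict:
        value = getattr(optimizer_config, "factored_fp32", None)
        # 1.9.0 default is True, per the comment above.
        return {"factored_fp32": value if value is not None else True}

    # Usage inside create_optimizer():
    #   ProdigyPlusScheduleFree(params=parameters, **factored_fp32_kwargs(optimizer_config), ...)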


# Profiling
scalene==1.5.45