Skip to content

Commit

Permalink
Merge pull request #32 from francois-rozet/29-warmup
Browse files Browse the repository at this point in the history
Fix learning rate warmup steps
  • Loading branch information
ClashLuke authored Jan 8, 2025
2 parents cfc7f34 + 88eb504 commit 651cdc6
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 14 deletions.
4 changes: 2 additions & 2 deletions heavyball/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class ForeachSOAP(C.BaseOpt):
def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 0,
split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
Expand Down Expand Up @@ -162,7 +162,7 @@ class ForeachPSGDKron(C.BaseOpt):

def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
Expand Down
7 changes: 1 addition & 6 deletions heavyball/chainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,7 @@ def _step(self, group):
break

group['step'] = state['step'] = step = step + 1

if group['warmup_steps'] and step < group['warmup_steps']:
group['prev_lr'] = group['lr'] = group['base_lr'] * step / group['warmup_steps']

else:
group['prev_lr'] = group['lr'] = group['base_lr']
group['prev_lr'] = group['lr'] = group['base_lr'] * step / max(step, group['warmup_steps'] + 1)

if not group['foreach'] or len(p) == 1:
for param, grad in zip(p, g):
Expand Down
6 changes: 0 additions & 6 deletions heavyball/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,6 @@ def _fn(*args, **kwargs):
einsum_base = string.ascii_lowercase + string.ascii_uppercase


def warmup(lr: float, step: int, warmup_steps: int):
    """Linearly ramp *lr* over the first *warmup_steps* optimizer steps.

    While ``step < warmup_steps`` the learning rate is scaled by
    ``step / warmup_steps``; afterwards the full ``lr`` is returned.
    Checking the warmed-up branch first also guards against a
    zero division when ``warmup_steps`` is 0 (0 >= 0 is true, so the
    divide is never reached).
    """
    if step < warmup_steps:
        return lr * step / warmup_steps
    return lr


@decorator_knowngood
def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
beta1: Tensor, decay: float, grad: List[Tensor], caution):
Expand Down

0 comments on commit 651cdc6

Please sign in to comment.