Merge pull request #328 from kozistr/update/wrapper
[Update] proper property
kozistr authored Jan 21, 2025
2 parents 5974aef + 4a1b3b7 · commit cdbd9bc
Showing 6 changed files with 42 additions and 34 deletions.
docs/changelogs/v3.3.5.md (9 additions, 0 deletions)

@@ -0,0 +1,9 @@
+### Change Log
+
+### Fix
+
+* Add the missing `state` property in `OrthoGrad` optimizer. (#326, #327)
+
+### Contributions
+
+thanks to @Vectorrent
poetry.lock (19 additions, 19 deletions)

Generated file; the diff is not rendered by default.

pytorch_optimizer/optimizer/orthograd.py (8 additions, 12 deletions)

@@ -1,17 +1,10 @@
-from collections import defaultdict
 from typing import Callable, Dict
 
 import torch
 from torch.optim import Optimizer
 
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import (
-    CLOSURE,
-    DEFAULTS,
-    LOSS,
-    OPTIMIZER_INSTANCE_OR_CLASS,
-    STATE,
-)
+from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
 
 
 class OrthoGrad(BaseOptimizer):
@@ -27,8 +20,6 @@ def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
         self.eps: float = 1e-30
 
-        self.state: STATE = defaultdict(dict)
-
         if isinstance(optimizer, Optimizer):
             self.optimizer = optimizer
         elif 'params' in kwargs:
@@ -46,8 +37,13 @@ def __str__(self) -> str:
     def param_groups(self):
         return self.optimizer.param_groups
 
-    def __getstate__(self):
-        return {'optimizer': self.optimizer}
+    @property
+    def state(self):
+        return self.optimizer.state
+
+    @torch.no_grad()
+    def zero_grad(self) -> None:
+        self.optimizer.zero_grad(set_to_none=True)
 
     @torch.no_grad()
     def reset(self):
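
The substantive change is the switch from a stored attribute to a read-through property. The removed `self.state: STATE = defaultdict(dict)` was created once at construction and never written to afterwards, so code inspecting the wrapper's `state` (schedulers, checkpointing) saw an empty mapping instead of the wrapped optimizer's buffers. Below is a minimal, self-contained sketch of the delegation pattern, using a hypothetical stand-in class rather than the library's actual `OrthoGrad`:

```python
import torch
from torch.optim import SGD, Optimizer


class WrapperSketch:
    """Hypothetical stand-in for an optimizer wrapper like OrthoGrad."""

    def __init__(self, optimizer: Optimizer) -> None:
        self.optimizer = optimizer

    @property
    def state(self):
        # Read-through property: always reflects the wrapped optimizer's
        # current state. The removed `self.state = defaultdict(dict)` was a
        # separate mapping that stayed empty no matter what the inner
        # optimizer accumulated.
        return self.optimizer.state

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @torch.no_grad()
    def zero_grad(self) -> None:
        self.optimizer.zero_grad(set_to_none=True)


param = torch.nn.Parameter(torch.zeros(3))
inner = SGD([param], lr=0.1, momentum=0.9)
wrapper = WrapperSketch(inner)

param.grad = torch.ones(3)
inner.step()  # with momentum, SGD stores a momentum buffer per parameter

assert wrapper.state is inner.state  # the property forwards, no stale copy
assert len(wrapper.state) == 1       # the momentum buffer is visible
wrapper.zero_grad()
assert param.grad is None            # set_to_none=True clears grads to None
```

Making `state` a class-level property also leaves the instance `__dict__` picklable as-is, which is presumably why the hand-written `__getstate__` override could be dropped.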
requirements-dev.txt (1 addition, 1 deletion)

@@ -22,7 +22,7 @@ platformdirs==4.3.6 ; python_version >= "3.8"
 pluggy==1.5.0 ; python_version >= "3.8"
 pytest-cov==5.0.0 ; python_version >= "3.8"
 pytest==8.3.4 ; python_version >= "3.8"
-ruff==0.9.1 ; python_version >= "3.8"
+ruff==0.9.2 ; python_version >= "3.8"
 setuptools==75.8.0 ; python_version >= "3.12"
 sympy==1.12.1 ; python_version == "3.8"
 sympy==1.13.1 ; python_version >= "3.9"
tests/test_gradients.py (1 addition, 1 deletion)

@@ -37,7 +37,7 @@ def test_no_gradients(optimizer_name):
     sphere_loss(p1 + p3).backward(create_graph=True)
 
     optimizer.step(lambda: 0.1)  # for AliG optimizer
-    if optimizer_name not in {'lookahead', 'trac'}:
+    if optimizer_name not in {'lookahead', 'trac', 'orthograd'}:
         optimizer.zero_grad(set_to_none=True)
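
A plausible reason for the new `'orthograd'` exclusion (an inference from the signatures in this diff, not something the PR states): the `zero_grad` override added to `OrthoGrad` accepts no arguments, so the keyword call on the last line above would fail for it:

```python
optimizer.zero_grad()                  # fine: the OrthoGrad override takes no arguments
optimizer.zero_grad(set_to_none=True)  # TypeError: zero_grad() got an unexpected
                                       # keyword argument 'set_to_none'
```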
tests/test_utils.py (4 additions, 1 deletion)

@@ -265,6 +265,9 @@ def test_cpu_offload_optimizer():
 def test_orthograd_name():
     optimizer = build_orthograd(Example().parameters())
+    optimizer.zero_grad()
+
     _ = optimizer.param_groups
-    _ = optimizer.__getstate__()
+    _ = optimizer.state
 
     assert str(optimizer).lower() == 'orthograd'
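
For reference, a sketch of the same surface exercised through the public API; it assumes `OrthoGrad` is importable from the package root and accepts an optimizer class plus a `params` keyword, as the constructor in the diff above suggests (`Example` and `build_orthograd` are test fixtures, so a plain `nn.Linear` stands in here):

```python
import torch
from pytorch_optimizer import OrthoGrad

model = torch.nn.Linear(2, 1)
optimizer = OrthoGrad(optimizer=torch.optim.AdamW, params=model.parameters())

optimizer.zero_grad()       # new override, delegates to the wrapped optimizer
_ = optimizer.param_groups  # existing property
_ = optimizer.state         # the property this PR adds

assert str(optimizer).lower() == 'orthograd'
```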
