diff --git a/docs/changelogs/v3.3.5.md b/docs/changelogs/v3.3.5.md new file mode 100644 index 00000000..bb645c9a --- /dev/null +++ b/docs/changelogs/v3.3.5.md @@ -0,0 +1,9 @@ +### Change Log + +### Fix + +* Add the missing `state` property in `OrthoGrad` optimizer. (#326, #327) + +### Contributions + +thanks to @Vectorrent diff --git a/poetry.lock b/poetry.lock index a9c98f32..ea04ec6d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -589,29 +589,29 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.9.1" +version = "0.9.2" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.9.1-py3-none-linux_armv6l.whl", hash = "sha256:84330dda7abcc270e6055551aca93fdde1b0685fc4fd358f26410f9349cf1743"}, - {file = "ruff-0.9.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:3cae39ba5d137054b0e5b472aee3b78a7c884e61591b100aeb544bcd1fc38d4f"}, - {file = "ruff-0.9.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:50c647ff96f4ba288db0ad87048257753733763b409b2faf2ea78b45c8bb7fcb"}, - {file = "ruff-0.9.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0c8b149e9c7353cace7d698e1656ffcf1e36e50f8ea3b5d5f7f87ff9986a7ca"}, - {file = "ruff-0.9.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:beb3298604540c884d8b282fe7625651378e1986c25df51dec5b2f60cafc31ce"}, - {file = "ruff-0.9.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39d0174ccc45c439093971cc06ed3ac4dc545f5e8bdacf9f067adf879544d969"}, - {file = "ruff-0.9.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:69572926c0f0c9912288915214ca9b2809525ea263603370b9e00bed2ba56dbd"}, - {file = "ruff-0.9.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:937267afce0c9170d6d29f01fcd1f4378172dec6760a9f4dface48cdabf9610a"}, - {file = "ruff-0.9.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:186c2313de946f2c22bdf5954b8dd083e124bcfb685732cfb0beae0c47233d9b"}, - {file = "ruff-0.9.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f94942a3bb767675d9a051867c036655fe9f6c8a491539156a6f7e6b5f31831"}, - {file = "ruff-0.9.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:728d791b769cc28c05f12c280f99e8896932e9833fef1dd8756a6af2261fd1ab"}, - {file = "ruff-0.9.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2f312c86fb40c5c02b44a29a750ee3b21002bd813b5233facdaf63a51d9a85e1"}, - {file = "ruff-0.9.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ae017c3a29bee341ba584f3823f805abbe5fe9cd97f87ed07ecbf533c4c88366"}, - {file = "ruff-0.9.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5dc40a378a0e21b4cfe2b8a0f1812a6572fc7b230ef12cd9fac9161aa91d807f"}, - {file = "ruff-0.9.1-py3-none-win32.whl", hash = "sha256:46ebf5cc106cf7e7378ca3c28ce4293b61b449cd121b98699be727d40b79ba72"}, - {file = "ruff-0.9.1-py3-none-win_amd64.whl", hash = "sha256:342a824b46ddbcdddd3abfbb332fa7fcaac5488bf18073e841236aadf4ad5c19"}, - {file = "ruff-0.9.1-py3-none-win_arm64.whl", hash = "sha256:1cd76c7f9c679e6e8f2af8f778367dca82b95009bc7b1a85a47f1521ae524fa7"}, - {file = "ruff-0.9.1.tar.gz", hash = "sha256:fd2b25ecaf907d6458fa842675382c8597b3c746a2dde6717fe3415425df0c17"}, + {file = "ruff-0.9.2-py3-none-linux_armv6l.whl", hash = "sha256:80605a039ba1454d002b32139e4970becf84b5fee3a3c3bf1c2af6f61a784347"}, + {file = "ruff-0.9.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b9aab82bb20afd5f596527045c01e6ae25a718ff1784cb92947bff1f83068b00"}, + {file = "ruff-0.9.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fbd337bac1cfa96be615f6efcd4bc4d077edbc127ef30e2b8ba2a27e18c054d4"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82b35259b0cbf8daa22a498018e300b9bb0174c2bbb7bcba593935158a78054d"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b6a9701d1e371bf41dca22015c3f89769da7576884d2add7317ec1ec8cb9c3c"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9cc53e68b3c5ae41e8faf83a3b89f4a5d7b2cb666dff4b366bb86ed2a85b481f"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8efd9da7a1ee314b910da155ca7e8953094a7c10d0c0a39bfde3fcfd2a015684"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3292c5a22ea9a5f9a185e2d131dc7f98f8534a32fb6d2ee7b9944569239c648d"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a605fdcf6e8b2d39f9436d343d1f0ff70c365a1e681546de0104bef81ce88df"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c547f7f256aa366834829a08375c297fa63386cbe5f1459efaf174086b564247"}, + {file = "ruff-0.9.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d18bba3d3353ed916e882521bc3e0af403949dbada344c20c16ea78f47af965e"}, + {file = "ruff-0.9.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b338edc4610142355ccf6b87bd356729b62bf1bc152a2fad5b0c7dc04af77bfe"}, + {file = "ruff-0.9.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:492a5e44ad9b22a0ea98cf72e40305cbdaf27fac0d927f8bc9e1df316dcc96eb"}, + {file = "ruff-0.9.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:af1e9e9fe7b1f767264d26b1075ac4ad831c7db976911fa362d09b2d0356426a"}, + {file = "ruff-0.9.2-py3-none-win32.whl", hash = "sha256:71cbe22e178c5da20e1514e1e01029c73dc09288a8028a5d3446e6bba87a5145"}, + {file = "ruff-0.9.2-py3-none-win_amd64.whl", hash = "sha256:c5e1d6abc798419cf46eed03f54f2e0c3adb1ad4b801119dedf23fcaf69b55b5"}, + {file = "ruff-0.9.2-py3-none-win_arm64.whl", hash = "sha256:a1b63fa24149918f8b37cef2ee6fff81f24f0d74b6f0bdc37bc3e1f2143e41c6"}, + {file = "ruff-0.9.2.tar.gz", hash = "sha256:b5eceb334d55fae5f316f783437392642ae18e16dcf4f1858d55d3c2a0f8f5d0"}, ] [[package]] diff --git a/pytorch_optimizer/optimizer/orthograd.py b/pytorch_optimizer/optimizer/orthograd.py index a291b9dd..3467e33f 100644 --- a/pytorch_optimizer/optimizer/orthograd.py +++ b/pytorch_optimizer/optimizer/orthograd.py @@ -1,17 +1,10 @@ -from collections import defaultdict from typing import Callable, Dict import torch from torch.optim import Optimizer from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import ( - CLOSURE, - DEFAULTS, - LOSS, - OPTIMIZER_INSTANCE_OR_CLASS, - STATE, -) +from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS class OrthoGrad(BaseOptimizer): @@ -27,8 +20,6 @@ def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None: self._optimizer_step_post_hooks: Dict[int, Callable] = {} self.eps: float = 1e-30 - self.state: STATE = defaultdict(dict) - if isinstance(optimizer, Optimizer): self.optimizer = optimizer elif 'params' in kwargs: @@ -46,8 +37,13 @@ def __str__(self) -> str: def param_groups(self): return self.optimizer.param_groups - def __getstate__(self): - return {'optimizer': self.optimizer} + @property + def state(self): + return self.optimizer.state + + @torch.no_grad() + def zero_grad(self) -> None: + self.optimizer.zero_grad(set_to_none=True) @torch.no_grad() def reset(self): diff --git a/requirements-dev.txt b/requirements-dev.txt index e7d5c4ab..7eb11ec6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -22,7 +22,7 @@ platformdirs==4.3.6 ; python_version >= "3.8" pluggy==1.5.0 ; python_version >= "3.8" pytest-cov==5.0.0 ; python_version >= "3.8" pytest==8.3.4 ; python_version >= "3.8" -ruff==0.9.1 ; python_version >= "3.8" +ruff==0.9.2 ; python_version >= "3.8" setuptools==75.8.0 ; python_version >= "3.12" sympy==1.12.1 ; python_version == "3.8" sympy==1.13.1 ; python_version >= "3.9" diff --git a/tests/test_gradients.py b/tests/test_gradients.py index a8e27372..bc159aaf 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -37,7 +37,7 @@ def test_no_gradients(optimizer_name): sphere_loss(p1 + p3).backward(create_graph=True) optimizer.step(lambda: 0.1) # for AliG optimizer - if optimizer_name not in {'lookahead', 'trac'}: + if optimizer_name not in {'lookahead', 'trac', 'orthograd'}: optimizer.zero_grad(set_to_none=True) diff --git a/tests/test_utils.py b/tests/test_utils.py index 462c84ee..19b00069 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -265,6 +265,9 @@ def test_cpu_offload_optimizer(): def test_orthograd_name(): optimizer = build_orthograd(Example().parameters()) + optimizer.zero_grad() + _ = optimizer.param_groups - _ = optimizer.__getstate__() + _ = optimizer.state + assert str(optimizer).lower() == 'orthograd'