
Commit

fix crashing optim
Vectorrent committed Jan 22, 2025
1 parent cdbd9bc · commit 2abee6e
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions pytorch_optimizer/optimizer/orthograd.py
@@ -1,10 +1,17 @@
+from collections import defaultdict
 from typing import Callable, Dict
 
 import torch
 from torch.optim import Optimizer
 
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
+from pytorch_optimizer.base.types import (
+    CLOSURE,
+    DEFAULTS,
+    LOSS,
+    OPTIMIZER_INSTANCE_OR_CLASS,
+    STATE,
+)
 
 
 class OrthoGrad(BaseOptimizer):
@@ -18,8 +25,14 @@ class OrthoGrad(BaseOptimizer):
     def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
         self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
+        self._optimizer_state_dict_pre_hooks: Dict[int, Callable] = {}
+        self._optimizer_state_dict_post_hooks: Dict[int, Callable] = {}
+        self._optimizer_load_state_dict_pre_hooks: Dict[int, Callable] = {}
+        self._optimizer_load_state_dict_post_hooks: Dict[int, Callable] = {}
         self.eps: float = 1e-30
 
+        self.state: STATE = defaultdict(dict)
+
         if isinstance(optimizer, Optimizer):
             self.optimizer = optimizer
         elif 'params' in kwargs:
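The hunk above suggests why the wrapper was crashing: it never calls torch.optim.Optimizer.__init__, so the private hook registries that recent PyTorch versions create there, plus a real per-parameter state mapping, have to be set up by hand. A minimal sketch of that assumption, using only standard PyTorch (the attribute names are taken from the diff, not from this commit's surrounding code):

# Sketch: recent torch.optim.Optimizer.__init__ creates these private hook
# registries and a per-parameter state mapping; a wrapper that skips that
# __init__ must define equivalents itself or later attribute lookups fail.
from collections import defaultdict

import torch

param = torch.nn.Parameter(torch.zeros(1))
sgd = torch.optim.SGD([param], lr=0.1)

# Present on a stock optimizer (expected True on recent PyTorch releases).
print(hasattr(sgd, '_optimizer_step_pre_hooks'))
print(hasattr(sgd, '_optimizer_state_dict_pre_hooks'))

# The wrapper mirrors the same shape for its own state: one dict per parameter.
state = defaultdict(dict)
state[param]['step'] = 0
print(dict(state[param]))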
@@ -37,13 +50,8 @@ def __str__(self) -> str:
     @property
     def param_groups(self):
         return self.optimizer.param_groups
 
-    @property
-    def state(self):
-        return self.optimizer.state
-
-    @torch.no_grad()
-    def zero_grad(self) -> None:
-        self.optimizer.zero_grad(set_to_none=True)
+    def __getstate__(self):
+        return {'optimizer': self.optimizer}
 
Codecov / codecov/patch — check warning on line 54 in pytorch_optimizer/optimizer/orthograd.py (#L54): added line #L54 was not covered by tests.

     @torch.no_grad()
     def reset(self):
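A short usage sketch of the wrapper, assuming the parts of the class not shown in this diff (notably step()) behave as in the released pytorch_optimizer package; the toy model, data, and learning rate are invented for illustration:

import torch

from pytorch_optimizer.optimizer.orthograd import OrthoGrad

model = torch.nn.Linear(4, 1)
base = torch.optim.SGD(model.parameters(), lr=1e-2)

# The constructor accepts an already-built optimizer instance (or a class
# plus 'params' in kwargs, per the __init__ shown above).
optimizer = OrthoGrad(base)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()                    # steps the wrapped SGD after OrthoGrad's gradient adjustment
base.zero_grad(set_to_none=True)    # clear gradients via the wrapped optimizer

Since this commit drops the wrapper's own zero_grad override, the sketch clears gradients through the wrapped optimizer directly.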
