Skip to content

Commit

Permalink
fix: reset
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Aug 7, 2024
1 parent 4f4d359 commit e94290f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/trac.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def state(self):
def reset(self):
device = self.param_groups[0]['params'][0].device

self.state = {
self.state['trac'] = {
'betas': torch.tensor(self.betas, device=device),
's': torch.zeros(len(self.betas), device=device),
'variance': torch.zeros(len(self.betas), device=device),
Expand All @@ -148,7 +148,7 @@ def reset(self):

for group in self.param_groups:
for p in group['params']:
self.state[p] = p.clone()
self.state['trac'][p] = p.clone()

@torch.no_grad()
def zero_grad(self) -> None:
Expand Down

0 comments on commit e94290f

Please sign in to comment.