From d200d9977b4ebda395483eb811def247fa7bc4f4 Mon Sep 17 00:00:00 2001 From: kozistr Date: Tue, 13 Aug 2024 19:39:41 +0900 Subject: [PATCH 1/3] update: disable coverage --- pytorch_optimizer/optimizer/prodigy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_optimizer/optimizer/prodigy.py b/pytorch_optimizer/optimizer/prodigy.py index ccef49c5..22732037 100644 --- a/pytorch_optimizer/optimizer/prodigy.py +++ b/pytorch_optimizer/optimizer/prodigy.py @@ -111,6 +111,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: if 'd_numerator' not in group: group['d_numerator'] = torch.tensor([0.0], device=device) elif group['d_numerator'].device != device: + # pragma: no cover group['d_numerator'] = group['d_numerator'].to(device) d_numerator = group['d_numerator'] From 537a0d7063448efd7b2ab4bb9ddd2c3967378d40 Mon Sep 17 00:00:00 2001 From: kozistr Date: Tue, 13 Aug 2024 19:42:28 +0900 Subject: [PATCH 2/3] update: code --- tests/test_optimizers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index aaf9024e..b862661f 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -76,8 +76,7 @@ def _closure() -> float: for _ in range(iterations): optimizer.zero_grad() - y_pred = model(x_data) - loss = loss_fn(y_pred, y_data) + loss = loss_fn(model(x_data), y_data) if init_loss == np.inf: init_loss = loss From 3f38e3fb871a030d890d545db8b259a5563808fa Mon Sep 17 00:00:00 2001 From: kozistr Date: Tue, 13 Aug 2024 19:46:09 +0900 Subject: [PATCH 3/3] update: code --- pytorch_optimizer/optimizer/prodigy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_optimizer/optimizer/prodigy.py b/pytorch_optimizer/optimizer/prodigy.py index 22732037..ba0bb7b0 100644 --- a/pytorch_optimizer/optimizer/prodigy.py +++ b/pytorch_optimizer/optimizer/prodigy.py @@ -111,8 +111,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: if 'd_numerator' not in group: group['d_numerator'] = torch.tensor([0.0], device=device) elif group['d_numerator'].device != device: - # pragma: no cover - group['d_numerator'] = group['d_numerator'].to(device) + group['d_numerator'] = group['d_numerator'].to(device) # pragma: no cover d_numerator = group['d_numerator'] d_numerator.mul_(beta3)