diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py
index 2522a9de..53eb3e50 100644
--- a/opacus/optimizers/optimizer.py
+++ b/opacus/optimizers/optimizer.py
@@ -396,7 +396,9 @@ def clip_and_accumulate(self):
         if len(self.grad_samples[0]) == 0:
             # Empty batch
-            per_sample_clip_factor = torch.zeros((0,))
+            per_sample_clip_factor = torch.zeros(
+                (0,), device=self.grad_samples[0].device
+            )
         else:
             per_param_norms = [
                 g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
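
Why the device matters, as a hedged sketch (not part of the patch): `torch.zeros((0,))` allocates on the CPU by default, so when the per-sample gradients live on a GPU, the empty-batch clip factor ends up on a different device than the gradients it is later contracted against. The shapes and the `einsum` call below are illustrative stand-ins for the Opacus internals, not its exact code.

```python
import torch

# Minimal sketch of the failure mode this patch fixes, assuming the
# per-sample gradients live on a CUDA device.
if torch.cuda.is_available():
    # Empty batch: zero rows of per-sample gradients, resident on the GPU.
    grad_sample = torch.zeros((0, 10), device="cuda")

    old_factor = torch.zeros((0,))                             # pre-patch: defaults to CPU
    new_factor = torch.zeros((0,), device=grad_sample.device)  # patched: follows the grads

    # Contracting the clip factor against the gradients only works when both
    # tensors share a device; the CPU factor trips a device-mismatch error.
    torch.einsum("i,i...", new_factor, grad_sample)            # fine: both on cuda
    try:
        torch.einsum("i,i...", old_factor, grad_sample)        # cpu factor vs. cuda grads
    except RuntimeError as err:
        print(f"device mismatch: {err}")
```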