diff --git a/jacobian/jacobian.py b/jacobian/jacobian.py index 06bf734..041edfb 100644 --- a/jacobian/jacobian.py +++ b/jacobian/jacobian.py @@ -51,7 +51,7 @@ def forward(self, x, y): # random properly-normalized vector for each sample v = self._random_vector(C=C,B=B) if x.is_cuda: - v = v.cuda() + v = v.to(x.device) Jv = self._jacobian_vector_product(y, x, v, create_graph=True) J2 += C*torch.norm(Jv)**2 / (num_proj*B) R = (1/2)*J2