diff --git a/quaterion/utils/utils.py b/quaterion/utils/utils.py index 983d4906..faefd8a1 100644 --- a/quaterion/utils/utils.py +++ b/quaterion/utils/utils.py @@ -1,6 +1,7 @@ from typing import Iterable, Optional, Sized, Union import torch +import torch.nn.functional as F import tqdm from torch.utils.data import Dataset @@ -259,6 +260,4 @@ def l2_norm(inputs: torch.Tensor, dim: int = 0) -> torch.Tensor: Returns: torch.Tensor: L2-normalized tensor """ - outputs = inputs / torch.norm(inputs, 2, dim, True) - - return outputs + return F.normalize(inputs, p=2, dim=dim)