From 2371ddf0bfd860ac29e0f6e24a238d342495a330 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karahan=20Sar=C4=B1ta=C5=9F?= <44376034+KarahanS@users.noreply.github.com> Date: Tue, 4 Apr 2023 10:59:30 +0300 Subject: [PATCH] switch from torch.norm to torch.nn.functional.normalize #203 (#206) --- quaterion/utils/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/quaterion/utils/utils.py b/quaterion/utils/utils.py index 983d4906..faefd8a1 100644 --- a/quaterion/utils/utils.py +++ b/quaterion/utils/utils.py @@ -1,6 +1,7 @@ from typing import Iterable, Optional, Sized, Union import torch +import torch.nn.functional as F import tqdm from torch.utils.data import Dataset @@ -259,6 +260,4 @@ def l2_norm(inputs: torch.Tensor, dim: int = 0) -> torch.Tensor: Returns: torch.Tensor: L2-normalized tensor """ - outputs = inputs / torch.norm(inputs, 2, dim, True) - - return outputs + return F.normalize(inputs, p=2, dim=dim)