From 0cbbf2df726e9c4dee555c653d7313bfee615afc Mon Sep 17 00:00:00 2001 From: Franck Mamalet Date: Mon, 1 Jul 2024 15:10:23 +0200 Subject: [PATCH] paranthesis tricks bjork --- deel/torchlip/normalizers.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/deel/torchlip/normalizers.py b/deel/torchlip/normalizers.py index a4a2403..a37e60e 100644 --- a/deel/torchlip/normalizers.py +++ b/deel/torchlip/normalizers.py @@ -70,10 +70,17 @@ def bjorck_normalization( shape = w.shape cout = w.size(0) w_mat = w.reshape(cout, -1) - for i in range(niter): - w_mat = (1.0 + beta) * w_mat - beta * torch.mm( - w_mat, torch.mm(w_mat.t(), w_mat) - ) + + if w_mat.shape[0]>w_mat.shape[1]: + for i in range(niter): + w_mat = (1.0 + beta) * w_mat - beta * torch.mm( + w_mat, torch.mm(w_mat.t(), w_mat) + ) + else: + for i in range(niter): + w_mat = (1.0 + beta) * w_mat - beta * torch.mm( + torch.mm(w_mat, w_mat.t()), w_mat + ) w = w_mat.reshape(shape) return w