Skip to content

Commit

Permalink
test normalizer all close value
Browse files Browse the repository at this point in the history
  • Loading branch information
Franck Mamalet committed Oct 9, 2024
1 parent f1688dc commit 576de52
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
38 changes: 28 additions & 10 deletions tests/test_normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_kernel_svd(kernel_shape):
)
# Test if kernel is normalized by sigma
np.testing.assert_allclose(
np.reshape(W_bar, kernel.shape), kernel / sigmas_svd[0], 1e-2, 0
np.reshape(W_bar, kernel.shape), kernel / sigmas_svd[0], atol=1e-2
)


Expand Down Expand Up @@ -211,7 +211,9 @@ def test_kernel_conv_svd(kernel_shape, strides):
sigma, SVmax, 2, "test failed with kernel_shape " + str(kernel.shape)
)
# Test if kernel is normalized by sigma
np.testing.assert_allclose(np.reshape(W_bar, kernel.shape), kernel / SVmax, 1e-2, 0)
np.testing.assert_allclose(
np.reshape(W_bar, kernel.shape), kernel / SVmax, atol=1e-2
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -244,7 +246,9 @@ def test_bjorck_normalization(kernel_shape):
compute_uv=False,
)
# Test sigma is close to the one computed with svd first run @ 1e-1
np.testing.assert_allclose(sigmas_wbar_svd, np.ones(sigmas_wbar_svd.shape), 1e-2, 0)
np.testing.assert_allclose(
sigmas_wbar_svd, np.ones(sigmas_wbar_svd.shape), atol=1e-2
)
# Test W_bar is reshaped correctly
np.testing.assert_equal(wbar.shape, (np.prod(kernel.shape[:-1]), kernel.shape[-1]))

Expand All @@ -256,7 +260,9 @@ def test_bjorck_normalization(kernel_shape):
full_matrices=False,
compute_uv=False,
)
np.testing.assert_allclose(sigmas_wbar_svd, np.ones(sigmas_wbar_svd.shape), 1e-4, 0)
np.testing.assert_allclose(
sigmas_wbar_svd, np.ones(sigmas_wbar_svd.shape), atol=1e-4
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -306,7 +312,9 @@ def test_reshaped_kernel_orthogonalization(kernel_shape):
compute_uv=False,
)
# Test if SVs of W_bar are close to one
np.testing.assert_allclose(sigmas_wbar_svd, np.ones(sigmas_wbar_svd.shape), 1e-2, 0)
np.testing.assert_allclose(
sigmas_wbar_svd, np.ones(sigmas_wbar_svd.shape), atol=1e-2
)


@pytest.mark.skipif(
Expand Down Expand Up @@ -340,14 +348,18 @@ def test_bjorck_norm():
x = np.random.rand(2)
x = uft.to_tensor(x)
m(x)
np.testing.assert_equal(uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)))
np.testing.assert_allclose(
uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)), atol=1e-5
)

# remove the parametrization
uft.get_instance_framework(remove_bjorck_norm, {"module": m}) # (m)
uft.check_parametrization(m, is_parametrized=False)
# assert not hasattr(m, "parametrizations")
# assert isinstance(m.weight, torch.nn.Parameter)
np.testing.assert_equal(uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)))
np.testing.assert_allclose(
uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)), atol=1e-5
)


@pytest.mark.skipif(
Expand Down Expand Up @@ -379,12 +391,16 @@ def test_frobenius_norm():
x = np.random.rand(2)
x = uft.to_tensor(x)
m(x)
np.testing.assert_equal(uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)))
np.testing.assert_allclose(
uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)), atol=1e-5
)

# remove the parametrization
uft.get_instance_framework(remove_frobenius_norm, {"module": m})
uft.check_parametrization(m, is_parametrized=False)
np.testing.assert_equal(uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)))
np.testing.assert_allclose(
uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)), atol=1e-5
)


@pytest.mark.skipif(
Expand Down Expand Up @@ -459,4 +475,6 @@ def test_lconv_norm():
uft.get_instance_framework(remove_lconv_norm, {"module": m})
uft.check_parametrization(m, is_parametrized=False)
# assert isinstance(m.weight, torch.nn.Parameter)
np.testing.assert_equal(uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)))
np.testing.assert_allclose(
uft.to_numpy(w1), uft.to_numpy(uft.get_layer_weights(m)), atol=1e-7
)
1 change: 1 addition & 0 deletions tests/utils_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
"Loss",
]


# not implemented
def module_Unavailable(**kwargs):
return None
Expand Down

0 comments on commit 576de52

Please sign in to comment.