From 8c2a6f5071ea620121888cf4d1a1561fef1e0f45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 10 Jan 2024 13:35:06 +0100 Subject: [PATCH] Make SLQ integrands behave correctly for arbitrary input vectors (#174) * Make integrand_slq_spd behave correctly for arbitrary input vectors, not just unit-second-moment input vectors * Make integrand_slq_product behave correctly, too --- matfree/slq.py | 13 ++++++------ tests/test_slq/test_logdet_product.py | 30 +++++++++++++++++++++++++++ tests/test_slq/test_logdet_spd.py | 28 +++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/matfree/slq.py b/matfree/slq.py index 0b262dc..3cfac3f 100644 --- a/matfree/slq.py +++ b/matfree/slq.py @@ -25,7 +25,8 @@ def integrand_slq_spd(matfun, order, matvec, /): def quadform(v0, *parameters): v0_flat, v_unflatten = tree_util.ravel_pytree(v0) - v0_flat /= linalg.vector_norm(v0_flat) + length = linalg.vector_norm(v0_flat) + v0_flat /= length def matvec_flat(v_flat): v = v_unflatten(v_flat) @@ -50,10 +51,8 @@ def matvec_flat(v_flat): # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] - (dim,) = v0_flat.shape - fx_eigvals = func.vmap(matfun)(eigvals) - return dim * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) + return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) return quadform @@ -86,7 +85,8 @@ def integrand_slq_product(matfun, depth, matvec, vecmat, /): def quadform(v0, *parameters): v0_flat, v_unflatten = tree_util.ravel_pytree(v0) - v0_flat /= linalg.vector_norm(v0_flat) + length = linalg.vector_norm(v0_flat) + v0_flat /= length def matvec_flat(v_flat): v = v_unflatten(v_flat) @@ -115,10 +115,9 @@ def vecmat_flat(w_flat): # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] - _, ncols = matrix_shape eigvals, eigvecs = S**2, Vt.T fx_eigvals = func.vmap(matfun)(eigvals) - 
return ncols * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) + return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) return quadform diff --git a/tests/test_slq/test_logdet_product.py b/tests/test_slq/test_logdet_product.py index 8b765d3..2874496 100644 --- a/tests/test_slq/test_logdet_product.py +++ b/tests/test_slq/test_logdet_product.py @@ -38,3 +38,33 @@ def vecmat(x): expected = linalg.slogdet(A.T @ A)[1] print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected) assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails + + +@testing.parametrize("n", [50]) +# usually: ~1.5 * num_significant_eigvals. +# But logdet seems to converge sooo much faster. +def test_logdet_product_exact_for_full_order_lanczos(n): + r"""Computing v^\top f(A^\top @ A) v with max-order Lanczos is exact for _any_ v.""" + # Construct a (numerically nice) matrix + singular_values = np.sqrt(np.arange(1.0, 1.0 + n, step=1.0)) + A = test_util.asymmetric_matrix_from_singular_values( + singular_values, nrows=n, ncols=n + ) + + # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm + order = n - 1 + integrand = slq.integrand_logdet_product(order, lambda v: A @ v, lambda v: v @ A) + + # Construct a vector that does not have expected 2-norm equal to "dim" + x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 1 + + # Compute v^\top @ log(A) @ v via Lanczos + received = integrand(x) + + # Compute the "true" value of v^\top @ log(A) @ v via eigenvalues + eigvals, eigvecs = linalg.eigh(A.T @ A) + logA = eigvecs @ linalg.diagonal_matrix(np.log(eigvals)) @ eigvecs.T + expected = x.T @ logA @ x + + # They should be identical + assert np.allclose(received, expected) diff --git a/tests/test_slq/test_logdet_spd.py b/tests/test_slq/test_logdet_spd.py index fc43b5e..f03b2c1 100644 --- a/tests/test_slq/test_logdet_spd.py +++ b/tests/test_slq/test_logdet_spd.py @@ -36,3 +36,31 @@ def matvec(x): 
expected = linalg.slogdet(A)[1] print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected) assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails + + +@testing.parametrize("n", [50]) +# usually: ~1.5 * num_significant_eigvals. +# But logdet seems to converge sooo much faster. +def test_logdet_spd_exact_for_full_order_lanczos(n): + r"""Computing v^\top f(A) v with max-order Lanczos should be exact for _any_ v.""" + # Construct a (numerically nice) matrix + eigvals = np.arange(1.0, 1.0 + n, step=1.0) + A = test_util.symmetric_matrix_from_eigenvalues(eigvals) + + # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm + order = n - 1 + integrand = slq.integrand_logdet_spd(order, lambda v: A @ v) + + # Construct a vector that does not have expected 2-norm equal to "dim" + x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10 + + # Compute v^\top @ log(A) @ v via Lanczos + received = integrand(x) + + # Compute the "true" value of v^\top @ log(A) @ v via eigenvalues + eigvals, eigvecs = linalg.eigh(A) + logA = eigvecs @ linalg.diagonal_matrix(np.log(eigvals)) @ eigvecs.T + expected = x.T @ logA @ x + + # They should be identical + assert np.allclose(received, expected)