Skip to content

Commit

Permalink
Make SLQ integrands behave correctly for arbitrary input vectors (#174)
Browse files Browse the repository at this point in the history
* Make integrand_slq_spd behave correctly for arbitrary input vectors, not just unit-second-moment input vectors

* Make integrand_slq_product behave correctly, too
  • Loading branch information
pnkraemer authored Jan 10, 2024
1 parent 174af02 commit 8c2a6f5
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
13 changes: 6 additions & 7 deletions matfree/slq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def integrand_slq_spd(matfun, order, matvec, /):

def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
v0_flat /= linalg.vector_norm(v0_flat)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat):
v = v_unflatten(v_flat)
Expand All @@ -50,10 +51,8 @@ def matvec_flat(v_flat):

# Since Q orthogonal (orthonormal) to v0, Q v = Q[0],
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
(dim,) = v0_flat.shape

fx_eigvals = func.vmap(matfun)(eigvals)
return dim * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])
return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])

return quadform

Expand Down Expand Up @@ -86,7 +85,8 @@ def integrand_slq_product(matfun, depth, matvec, vecmat, /):

def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
v0_flat /= linalg.vector_norm(v0_flat)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat):
v = v_unflatten(v_flat)
Expand Down Expand Up @@ -115,10 +115,9 @@ def vecmat_flat(w_flat):

# Since Q orthogonal (orthonormal) to v0, Q v = Q[0],
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
_, ncols = matrix_shape
eigvals, eigvecs = S**2, Vt.T
fx_eigvals = func.vmap(matfun)(eigvals)
return ncols * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])
return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])

return quadform

Expand Down
30 changes: 30 additions & 0 deletions tests/test_slq/test_logdet_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,33 @@ def vecmat(x):
expected = linalg.slogdet(A.T @ A)[1]
print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)
assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails


@testing.parametrize("n", [50])
# usually: ~1.5 * num_significant_eigvals.
# But logdet seems to converge sooo much faster.
def test_logdet_product_exact_for_full_order_lanczos(n):
r"""Computing v^\top f(A^\top @ A) v with max-order Lanczos is exact for _any_ v."""
# Construct a (numerically nice) matrix
singular_values = np.sqrt(np.arange(1.0, 1.0 + n, step=1.0))
A = test_util.asymmetric_matrix_from_singular_values(
singular_values, nrows=n, ncols=n
)

# Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm
order = n - 1
integrand = slq.integrand_logdet_product(order, lambda v: A @ v, lambda v: v @ A)

# Construct a vector without that does not have expected 2-norm equal to "dim"
x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 1

# Compute v^\top @ log(A) @ v via Lanczos
received = integrand(x)

# Compute the "true" value of v^\top @ log(A) @ v via eigenvalues
eigvals, eigvecs = linalg.eigh(A.T @ A)
logA = eigvecs @ linalg.diagonal_matrix(np.log(eigvals)) @ eigvecs.T
expected = x.T @ logA @ x

# They should be identical
assert np.allclose(received, expected)
28 changes: 28 additions & 0 deletions tests/test_slq/test_logdet_spd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,31 @@ def matvec(x):
expected = linalg.slogdet(A)[1]
print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)
assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails


@testing.parametrize("n", [50])
# usually: ~1.5 * num_significant_eigvals.
# But logdet seems to converge sooo much faster.
def test_logdet_spd_exact_for_full_order_lanczos(n):
r"""Computing v^\top f(A) v with max-order Lanczos should be exact for _any_ v."""
# Construct a (numerically nice) matrix
eigvals = np.arange(1.0, 1.0 + n, step=1.0)
A = test_util.symmetric_matrix_from_eigenvalues(eigvals)

# Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm
order = n - 1
integrand = slq.integrand_logdet_spd(order, lambda v: A @ v)

# Construct a vector without that does not have expected 2-norm equal to "dim"
x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10

# Compute v^\top @ log(A) @ v via Lanczos
received = integrand(x)

# Compute the "true" value of v^\top @ log(A) @ v via eigenvalues
eigvals, eigvecs = linalg.eigh(A)
logA = eigvecs @ linalg.diagonal_matrix(np.log(eigvals)) @ eigvecs.T
expected = x.T @ logA @ x

# They should be identical
assert np.allclose(received, expected)

0 comments on commit 8c2a6f5

Please sign in to comment.