Skip to content

Commit

Permalink
Implement decomp.hessenberg() via the Arnoldi decomposition (#192)
Browse files Browse the repository at this point in the history
* Implement hessenberg() via Arnoldi

* Move outer() from np to linalg

* Simplify inner product code

* Update the README to mention Arnoldi

* Write a docstring for Hessenberg()
  • Loading branch information
pnkraemer authored May 27, 2024
1 parent cb86c32 commit 9a7acbd
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Builds on [JAX](https://jax.readthedocs.io/en/latest/).

- ⚡ Stochastic **trace estimation** including batching, control variates, and uncertainty quantification
- ⚡ A stand-alone implementation of **stochastic Lanczos quadrature**
- ⚡ Matrix-decomposition algorithms for **large sparse eigenvalue problems**
- ⚡ Matrix-decomposition algorithms for **large sparse eigenvalue problems**: tridiagonalisation, bidiagonalisation, Hessenberg factorisation via Lanczos and Arnoldi iterations
- ⚡ Polynomial methods for approximating **functions of large matrices**
- ⚡ Partial Cholesky **preconditioners** with and without pivoting

Expand Down
13 changes: 11 additions & 2 deletions matfree/backend/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,17 @@ def slogdet(x, /):
return jnp.linalg.slogdet(x)


def vecdot(x1, x2, /):
    """Return the dot product of ``x1`` and ``x2`` by delegating to ``jnp.dot``."""
    product = jnp.dot(x1, x2)
    return product
def inner(x1, x2, /):
    """Return the inner product of ``x1`` and ``x2`` by delegating to ``jnp.inner``."""
    # TODO: decide whether vecdot, vdot, dot, and matmul need separate wrappers.
    product = jnp.inner(x1, x2)
    return product


def outer(a, b, /):
    """Return the outer product of ``a`` and ``b`` by delegating to ``jnp.outer``."""
    product = jnp.outer(a, b)
    return product


def hilbert(n, /):
    """Return the ``n``-by-``n`` Hilbert matrix via ``jax.scipy.linalg.hilbert``."""
    matrix = jax.scipy.linalg.hilbert(n)
    return matrix


def diagonal(x, /, offset=0):
Expand Down
4 changes: 2 additions & 2 deletions matfree/backend/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def check_grads(fun, /, args, *, order, atol, rtol):
return jax.test_util.check_grads(fun, args, order=order, atol=atol, rtol=rtol)


def raises(err, /):
    """Return the ``pytest.raises`` context manager for the exception type ``err``."""
    context = pytest.raises(err)
    return context
def raises(err, /, match):
    """Return the ``pytest.raises`` context manager for ``err``.

    ``match`` is forwarded unchanged as the ``match`` keyword of
    ``pytest.raises``.
    """
    context = pytest.raises(err, match=match)
    return context


def warns(warning, /):
Expand Down
62 changes: 61 additions & 1 deletion matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _gram_schmidt_classical(vec, vectors): # Gram-Schmidt


def _gram_schmidt_classical_step(vec1, vec2):
    """Subtract from ``vec1`` its inner-product multiple of ``vec2``.

    Returns the pair ``(vec1 - coeff * vec2, coeff)`` where
    ``coeff = linalg.inner(vec1, vec2)``.
    NOTE(review): this is an orthogonal projection only if ``vec2`` has
    unit norm -- confirm against the callers.
    """
    # The coefficient is computed once via ``linalg.inner``; the earlier
    # ``linalg.vecdot`` assignment was a dead store that was immediately
    # overwritten.
    coeff = linalg.inner(vec1, vec2)
    vec_ortho = vec1 - coeff * vec2
    return vec_ortho, coeff

Expand Down Expand Up @@ -276,3 +276,63 @@ def _eigh_tridiag_sym(diag, off_diag):
dense_matrix = diag + offdiag1 + offdiag2
eigvals, eigvecs = linalg.eigh(dense_matrix)
return eigvals, eigvecs


def hessenberg(matvec, krylov_depth, /, *, reortho: str = "full"):
    """Construct an implementation of the Arnoldi iteration.

    Parameters
    ----------
    matvec
        Callable ``matvec(v, *params)`` that applies the linear operator
        to a vector.
    krylov_depth
        Number of Arnoldi steps. It is validated when the returned function
        is called and must satisfy ``1 <= krylov_depth <= len(v)``.
    reortho
        Either ``"none"`` or ``"full"``. With ``"full"``, every step applies
        a second orthogonalisation pass against the current basis.

    Returns
    -------
    estimate
        Function ``estimate(v, *params)`` that returns ``(Q, H, r, c)``:
        ``Q`` has shape ``(n, krylov_depth)``, ``H`` has shape
        ``(krylov_depth, krylov_depth)``, ``r`` is the vector left over
        after the last step (shape ``(n,)``), and ``c`` is the reciprocal
        of the norm of the initial vector ``v``.

    Raises
    ------
    TypeError
        If ``reortho`` is not one of the accepted values.
    """
    reortho_expected = ["none", "full"]
    if reortho not in reortho_expected:
        # Name the offending argument and show the value with repr() so the
        # message cannot be misread as "input for <value>".
        msg = (
            f"Unexpected input for reortho: {reortho!r}; "
            f"either of {reortho_expected} expected."
        )
        raise TypeError(msg)

    def estimate(v, *params):
        return _forward(matvec, krylov_depth, v, *params, reortho=reortho)

    return estimate


def _forward(matvec, krylov_depth, v, *params, reortho: str):
    """Run ``krylov_depth`` Arnoldi steps starting from the vector ``v``.

    Returns ``(Q, H, r, c)`` where ``Q`` and ``H`` are the arrays filled in
    by ``_forward_step``, ``r`` is the vector carried out of the last step,
    and ``c`` is the reciprocal of the norm of the initial ``v``.

    Raises
    ------
    ValueError
        If ``krylov_depth`` is smaller than one or larger than ``len(v)``.
    """
    msg = f"Parameter depth {krylov_depth} is outside the expected range"
    if krylov_depth < 1:
        raise ValueError(msg)
    if krylov_depth > len(v):
        raise ValueError(msg)

    # Allocate the outputs; unfilled columns stay zero until their step runs.
    (num_rows,) = np.shape(v)
    basis = np.zeros((num_rows, krylov_depth), dtype=v.dtype)
    hess = np.zeros((krylov_depth, krylov_depth), dtype=v.dtype)
    norm_of_v = linalg.vector_norm(v)

    def body(step, carry):
        basis_, hess_, vec_, length_ = carry
        return _forward_step(
            basis_, hess_, vec_, length_, matvec, *params, idx=step, reortho=reortho
        )

    carry_init = (basis, hess, v, norm_of_v)
    Q, H, v, _ = control_flow.fori_loop(0, krylov_depth, body, carry_init)
    return Q, H, v, 1 / norm_of_v


def _forward_step(Q, H, v, length, matvec, *params, idx, reortho: str):
    """Carry out a single Arnoldi step and return the updated state.

    Parameters
    ----------
    Q, H
        Basis and Hessenberg arrays; column ``idx`` of each is written.
    v, length
        Current (unnormalised) vector and its norm.
    matvec, *params
        The operator is applied as ``matvec(q, *params)``.
    idx
        Index of the column that this step fills.
    reortho
        Any value other than ``"none"`` triggers a second orthogonalisation
        pass.

    Returns
    -------
    ``(Q, H, v, length)`` where ``v`` is the orthogonalised ``matvec`` output
    and ``length`` its norm.
    """
    # Normalise the incoming vector and store it as the idx-th basis vector.
    q = v / length
    Q = Q.at[:, idx].set(q)

    # Apply the operator to the newest basis vector.
    v = matvec(q, *params)

    # Project onto the basis with the conjugate transpose so that complex
    # inputs are orthogonalised in the Hermitian sense; for real arrays
    # ``conj()`` is a no-op and the result is unchanged.
    Q_adjoint = Q.T.conj()
    h = Q_adjoint @ v
    v = v - Q @ h

    # Optional second pass to remove components reintroduced by round-off.
    if reortho != "none":
        v = v - Q @ (Q_adjoint @ v)

    length = linalg.vector_norm(v)

    # In the last step ``idx + 1`` lies outside ``h``; JAX drops out-of-bounds
    # ``.at[...].set`` updates by default, so the norm is then only returned.
    h = h.at[idx + 1].set(length)
    H = H.at[:, idx].set(h)

    return Q, H, v, length
6 changes: 3 additions & 3 deletions matfree/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def matrix_column(i):

def body(i, L):
element = matrix_element(i, i)
l_ii = np.sqrt(element - linalg.vecdot(L[i], L[i]))
l_ii = np.sqrt(element - linalg.inner(L[i], L[i]))

column = matrix_column(i)
l_ji = column - L @ L[i, :]
Expand Down Expand Up @@ -141,7 +141,7 @@ def body(i, carry):
diagonal = matrix_diagonal_p(permute=P_matrix)

# Find the largest entry for the residuals
residual_diag = diagonal - func.vmap(linalg.vecdot)(L, L)
residual_diag = diagonal - func.vmap(linalg.inner)(L, L)
res = np.abs(residual_diag)
k = np.argmax(res)

Expand All @@ -158,7 +158,7 @@ def body(i, carry):
# (The first line could also be accessed via
# residual_diag[k], but it might
# be more readable to do it again)
l_ii_squared = element - linalg.vecdot(L[i], L[i])
l_ii_squared = element - linalg.inner(L[i], L[i])
l_ii = np.sqrt(l_ii_squared)
l_ji = column - L @ L[i, :]
l_ji /= l_ii
Expand Down
6 changes: 3 additions & 3 deletions matfree/stochtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def integrand(v, *parameters):
Qv = matvec(v, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(v)
Qv_flat, _unflatten = tree_util.ravel_pytree(Qv)
return linalg.vecdot(v_flat, Qv_flat)
return linalg.inner(v_flat, Qv_flat)

return integrand

Expand All @@ -75,7 +75,7 @@ def integrand(v, *parameters):
Qv = matvec(v, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(v)
Qv_flat, _unflatten = tree_util.ravel_pytree(Qv)
trace_form = linalg.vecdot(v_flat, Qv_flat)
trace_form = linalg.inner(v_flat, Qv_flat)
diagonal_form = unflatten(v_flat * Qv_flat)
return {"trace": trace_form, "diagonal": diagonal_form}

Expand All @@ -88,7 +88,7 @@ def integrand_frobeniusnorm_squared(matvec, /):
def integrand(vec, *parameters):
    """Return the inner product of the flattened ``matvec`` output with itself."""
    x = matvec(vec, *parameters)
    # Only the flat vector is needed; the unflatten callback is discarded.
    x_flat, _unflatten = tree_util.ravel_pytree(x)
    # A single return via ``linalg.inner``; the earlier ``linalg.vecdot``
    # return made this line unreachable.
    return linalg.inner(x_flat, x_flat)

return integrand

Expand Down
4 changes: 2 additions & 2 deletions matfree/stochtrace_funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def matvec_flat(v_flat, *p):
# Since Q orthogonal (orthonormal) to v0, Q v = Q[0],
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
fx_eigvals = func.vmap(matfun)(eigvals)
return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])
return length**2 * linalg.inner(eigvecs[0, :], fx_eigvals * eigvecs[0, :])

return quadform

Expand Down Expand Up @@ -109,7 +109,7 @@ def vecmat_flat(w_flat):
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
eigvals, eigvecs = S**2, Vt.T
fx_eigvals = func.vmap(matfun)(eigvals)
return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])
return length**2 * linalg.inner(eigvecs[0, :], fx_eigvals * eigvecs[0, :])

return quadform

Expand Down
4 changes: 2 additions & 2 deletions tests/test_decomp/test_bidiag.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_error_too_high_depth(A):
nrows, ncols = np.shape(A)
max_depth = min(nrows, ncols) - 1

with testing.raises(ValueError):
with testing.raises(ValueError, match=""):

def eye(v):
return v
Expand All @@ -82,7 +82,7 @@ def eye(v):
def test_error_too_low_depth(A):
"""Assert that a ValueError is raised when the depth is negative."""
min_depth = 0
with testing.raises(ValueError):
with testing.raises(ValueError, match=""):

def eye(v):
return v
Expand Down
81 changes: 81 additions & 0 deletions tests/test_decomp/test_hessenberg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Tests for Hessenberg factorisations (-> Arnoldi)."""

from matfree import decomp
from matfree.backend import linalg, np, prng, testing


@testing.parametrize("nrows", [10])
@testing.parametrize("krylov_depth", [1, 5, 10])
@testing.parametrize("reortho", ["none", "full"])
@testing.parametrize("dtype", [float])
def test_decomposition_is_satisfied(nrows, krylov_depth, reortho, dtype):
    """Check shapes and the Arnoldi relations on a random dense matrix."""
    # Random (well-conditioned) matrix and start vector
    key_matrix, key_vector = prng.prng_key(1), prng.prng_key(2)
    A = prng.normal(key_matrix, shape=(nrows, nrows), dtype=dtype)
    v = prng.normal(key_vector, shape=(nrows,), dtype=dtype)

    def matvec(s, p):
        return p @ s

    # Factorise
    arnoldi = decomp.hessenberg(matvec, krylov_depth, reortho=reortho)
    Q, H, r, c = arnoldi(v, A)

    # Shapes of all four outputs
    assert Q.shape == (nrows, krylov_depth)
    assert H.shape == (krylov_depth, krylov_depth)
    assert r.shape == (nrows,)
    assert c.shape == ()

    # Tolerances scale with the square root of machine epsilon
    sqrt_eps = np.sqrt(np.finfo_eps(np.dtype(H)))
    tols = {"atol": sqrt_eps, "rtol": sqrt_eps}

    # First and last unit vectors of length krylov_depth
    identity = np.eye(krylov_depth)
    e0, ek = identity[0], identity[-1]

    # A Q = Q H + r e_k^T, Q has orthonormal columns, and Q e_0 = c v
    residual = A @ Q - Q @ H - linalg.outer(r, ek)
    assert np.allclose(residual, 0.0, **tols)
    gram_error = Q.T.conj() @ Q - identity
    assert np.allclose(gram_error, 0.0, **tols)
    assert np.allclose(Q @ e0, c * v, **tols)


@testing.parametrize("nrows", [10])
@testing.parametrize("krylov_depth", [1, 5, 10])
@testing.parametrize("reortho", ["full"])
def test_reorthogonalisation_improves_the_estimate(nrows, krylov_depth, reortho):
    """Check the Arnoldi relations on a Hilbert matrix with reortho="full"."""
    # Create an ill-conditioned test-matrix (that requires reortho="full")
    A = linalg.hilbert(nrows)
    v = prng.normal(prng.prng_key(2), shape=(nrows,))

    # Decompose
    algorithm = decomp.hessenberg(lambda s, p: p @ s, krylov_depth, reortho=reortho)
    Q, H, r, c = algorithm(v, A)

    # Assert shapes
    assert Q.shape == (nrows, krylov_depth)
    assert H.shape == (krylov_depth, krylov_depth)
    assert r.shape == (nrows,)
    assert c.shape == ()

    # Tie the test-strictness to the floating point accuracy
    small_value = np.sqrt(np.finfo_eps(np.dtype(H)))
    tols = {"atol": small_value, "rtol": small_value}

    # Test the decompositions
    e0, ek = np.eye(krylov_depth)[[0, -1], :]
    assert np.allclose(A @ Q - Q @ H - linalg.outer(r, ek), 0.0, **tols)
    # Use the conjugate transpose, matching test_decomposition_is_satisfied;
    # for the real-valued data used here this is numerically identical.
    assert np.allclose(Q.T.conj() @ Q - np.eye(krylov_depth), 0.0, **tols)
    assert np.allclose(Q @ e0, c * v, **tols)


def test_raises_error_for_wrong_depth_too_small():
    """A Krylov depth of zero is rejected with a ValueError mentioning 'depth'."""

    def identity(s):
        return s

    too_small = 0
    algorithm = decomp.hessenberg(identity, too_small, reortho="none")
    with testing.raises(ValueError, match="depth"):
        _ = algorithm(np.ones((2,)))


def test_raises_error_for_wrong_depth_too_high():
    """A Krylov depth above the vector length is rejected with a ValueError."""

    def identity(s):
        return s

    too_high = 3  # the input vector below has only two entries
    algorithm = decomp.hessenberg(identity, too_high, reortho="none")
    with testing.raises(ValueError, match="depth"):
        _ = algorithm(np.ones((2,)))


@testing.parametrize("reortho_wrong", [True, "full_with_sparsity", "None"])
def test_raises_error_for_wrong_reorthogonalisation_flag(reortho_wrong):
    """Unsupported ``reortho`` values are rejected when the algorithm is built."""

    def identity(s):
        return s

    with testing.raises(TypeError, match="Unexpected input"):
        _ = decomp.hessenberg(identity, 1, reortho=reortho_wrong)
File renamed without changes.

0 comments on commit 9a7acbd

Please sign in to comment.