Skip to content

Commit

Permalink
Implement Lanczos-funm's via combining decompositions with various sm…
Browse files Browse the repository at this point in the history
…all funm's
  • Loading branch information
pnkraemer committed May 29, 2024
1 parent 4095933 commit 0ef76b9
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 19 deletions.
4 changes: 4 additions & 0 deletions matfree/backend/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ def solve(A, b, /):

def cg(Av, b, /):
return jax.scipy.sparse.linalg.cg(Av, b)


def funm_schur(A, f, /):
return jax.scipy.linalg.funm(A, f)
47 changes: 33 additions & 14 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
>>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,))
>>>
>>> # Compute a matrix-logarithm with Lanczos' algorithm
>>> matfun = dense_funm_sym_eigh(jnp.log)
>>> tridiag = decomp.tridiag_sym(lambda s: A @ s, 4)
>>> matfun_vec = funm_lanczos_sym(jnp.log, tridiag)
>>> matfun_vec = funm_lanczos_sym(matfun, tridiag)
>>> matfun_vec(v)
Array([-4.1, -1.3, -2.2, -2.1, -1.2, -3.3, -0.2, 0.3, 0.7, 0.9], dtype=float32)
"""
Expand Down Expand Up @@ -95,41 +96,59 @@ def matvec(vec, *parameters):
return matvec


def funm_lanczos_sym(matfun: Callable, tridiag_sym: Callable, /) -> Callable:
def funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable, /) -> Callable:
"""Implement a matrix-function-vector product via Lanczos' tridiagonalisation.
This algorithm uses Lanczos' tridiagonalisation
and therefore applies only to symmetric matrices.
Parameters
----------
matfun
Matrix function.
dense_funm
An implementation of a function of a dense matrix.
For example, the output of
[decomp.dense_funm_sym_eigh][matfree.decomp.dense_funm_sym_eigh]
[decomp.dense_funm_schur][matfree.decomp.dense_funm_schur]
tridiag_sym
Tridiagonalisation implementation.
Output of [decomp.tridiag_sym][matfree.decomp.tridiag_sym].
An implementation of tridiagonalisation.
E.g., the output of
[decomp.tridiag_sym][matfree.decomp.tridiag_sym].
"""

def estimate(vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
(basis, (diag, off_diag)), _ = tridiag_sym(vec, *parameters)
eigvals, eigvecs = _eigh_tridiag_sym(diag, off_diag)
matrix = _todense_tridiag_sym(diag, off_diag)

fx_eigvals = func.vmap(matfun)(eigvals)
return length * (basis.T @ (eigvecs @ (fx_eigvals * eigvecs[0, :])))
funm = dense_funm(matrix)
e1 = np.eye(len(matrix))[0, :]
return length * (basis.T @ funm @ e1)

return estimate


def _eigh_tridiag_sym(diag, off_diag):
def dense_funm_sym_eigh(matfun):
def fun(dense_matrix):
eigvals, eigvecs = linalg.eigh(dense_matrix)
fx_eigvals = func.vmap(matfun)(eigvals)
return eigvecs @ linalg.diagonal(fx_eigvals) @ eigvecs.T

return fun


def dense_funm_schur(matfun):
def fun(dense_matrix):
return linalg.funm_schur(dense_matrix, matfun)

return fun


def _todense_tridiag_sym(diag, off_diag):
# todo: once jax supports eigh_tridiagonal(eigvals_only=False),
# use it here. Until then: an eigen-decomposition of size (order + 1)
# does not hurt too much...
diag = linalg.diagonal_matrix(diag)
offdiag1 = linalg.diagonal_matrix(off_diag, -1)
offdiag2 = linalg.diagonal_matrix(off_diag, 1)
dense_matrix = diag + offdiag1 + offdiag2

eigvals, eigvecs = linalg.eigh(dense_matrix)
return eigvals, eigvecs
return diag + offdiag1 + offdiag2
11 changes: 6 additions & 5 deletions tests/test_funm/test_funm_lanczos_sym.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Test matrix-function-vector products via Lanczos' algorithm."""

from matfree import decomp, funm, test_util
from matfree.backend import linalg, np, prng
from matfree.backend import linalg, np, prng, testing


def test_funm_lanczos_sym_matches_eigh_implementation(n=11):
@testing.parametrize("dense_funm", [funm.dense_funm_sym_eigh, funm.dense_funm_schur])
def test_funm_lanczos_sym_matches_eigh_implementation(dense_funm, n=11):
"""Test matrix-function-vector products via Lanczos' algorithm."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).
Expand All @@ -26,8 +27,8 @@ def fun(x):
expected = log_matrix @ v

# Compute the matrix-function vector product
order = 6
lanczos = decomp.tridiag_sym(matvec, order)
matfun_vec = funm.funm_lanczos_sym(fun, lanczos)
dense_funm = dense_funm(fun)
lanczos = decomp.tridiag_sym(matvec, 6)
matfun_vec = funm.funm_lanczos_sym(dense_funm, lanczos)
received = matfun_vec(v, matrix)
assert np.allclose(expected, received, atol=1e-6)

0 comments on commit 0ef76b9

Please sign in to comment.