Skip to content

Commit

Permalink
Implement funm_via_lanczos by assuming a readily assembled Lanczos de…
Browse files Browse the repository at this point in the history
…composition
  • Loading branch information
pnkraemer committed May 29, 2024
1 parent e624132 commit 4095933
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
20 changes: 13 additions & 7 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
--------
>>> import jax.random
>>> import jax.numpy as jnp
>>> from matfree import decomp
>>>
>>> jnp.set_printoptions(1)
>>>
Expand All @@ -12,12 +13,12 @@
>>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,))
>>>
>>> # Compute a matrix-logarithm with Lanczos' algorithm
>>> matfun_vec = funm_lanczos_sym(jnp.log, 4, lambda s: A @ s)
>>> tridiag = decomp.tridiag_sym(lambda s: A @ s, 4)
>>> matfun_vec = funm_lanczos_sym(jnp.log, tridiag)
>>> matfun_vec(v)
Array([-4.1, -1.3, -2.2, -2.1, -1.2, -3.3, -0.2, 0.3, 0.7, 0.9], dtype=float32)
"""

from matfree import decomp
from matfree.backend import containers, control_flow, func, linalg, np
from matfree.backend.typing import Array, Callable

Expand Down Expand Up @@ -94,20 +95,25 @@ def matvec(vec, *parameters):
return matvec


# todo: if we pass decomp.tridiag_sym instead of order & matvec,
# the user gets more control over questions like reorthogonalisation
def funm_lanczos_sym(matfun: Callable, order: int, matvec: Callable, /) -> Callable:
def funm_lanczos_sym(matfun: Callable, tridiag_sym: Callable, /) -> Callable:
"""Implement a matrix-function-vector product via Lanczos' tridiagonalisation.
This algorithm uses Lanczos' tridiagonalisation
and therefore applies only to symmetric matrices.
Parameters
----------
matfun
Matrix function.
tridiag_sym
Tridiagonalisation implementation.
Output of [decomp.tridiag_sym][matfree.decomp.tridiag_sym].
"""
algorithm = decomp.tridiag_sym(matvec, order)

def estimate(vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
(basis, (diag, off_diag)), _ = algorithm(vec, *parameters)
(basis, (diag, off_diag)), _ = tridiag_sym(vec, *parameters)
eigvals, eigvecs = _eigh_tridiag_sym(diag, off_diag)

fx_eigvals = func.vmap(matfun)(eigvals)
Expand Down
7 changes: 4 additions & 3 deletions tests/test_funm/test_funm_lanczos_sym.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Test matrix-function-vector products via Lanczos' algorithm."""

from matfree import funm, test_util
from matfree import decomp, funm, test_util
from matfree.backend import linalg, np, prng


def test_funm_lanczos_sym(n=11):
def test_funm_lanczos_sym_matches_eigh_implementation(n=11):
"""Test matrix-function-vector products via Lanczos' algorithm."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).
Expand All @@ -27,6 +27,7 @@ def fun(x):

# Compute the matrix-function vector product
order = 6
matfun_vec = funm.funm_lanczos_sym(fun, order, matvec)
lanczos = decomp.tridiag_sym(matvec, order)
matfun_vec = funm.funm_lanczos_sym(fun, lanczos)
received = matfun_vec(v, matrix)
assert np.allclose(expected, received, atol=1e-6)

0 comments on commit 4095933

Please sign in to comment.