Skip to content

Commit

Permalink
Make integrand_funm_tridiag_sym expect algorithms as inputs so the us…
Browse files Browse the repository at this point in the history
…er can choose (e.g.) a reorthogonalisation mode (#210)

* Make integrand_funm_tridiag_sym expect a decomposition as an input to allow switching between (e.g.) reorthogonalisation modes

* Mark the next todos in the source
  • Loading branch information
pnkraemer authored Aug 29, 2024
1 parent 16f261b commit d7d2f36
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 19 deletions.
38 changes: 30 additions & 8 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def estimate(matvec: Callable, vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
Q, matrix, *_ = tridiag_sym(matvec, vec, *parameters)
# matrix = _todense_tridiag_sym(diag, off_diag)

funm = dense_funm(matrix)
e1 = np.eye(len(matrix))[0, :]
Expand Down Expand Up @@ -177,22 +176,41 @@ def estimate(matvec: Callable, vec, *parameters):
return estimate


def integrand_funm_sym_logdet(order, /):
def integrand_funm_sym_logdet(tridiag_sym: Callable, /):
"""Construct the integrand for the log-determinant.
This function assumes a symmetric, positive definite matrix.
Parameters
----------
tridiag_sym
An implementation of tridiagonalisation.
E.g., the output of
[decomp.tridiag_sym][matfree.decomp.tridiag_sym].
"""
return integrand_funm_sym(np.log, order)
dense_funm = dense_funm_sym_eigh(np.log)
return integrand_funm_sym(dense_funm, tridiag_sym)


def integrand_funm_sym(matfun, order, /):
def integrand_funm_sym(dense_funm, tridiag_sym, /):
"""Construct the integrand for matrix-function-trace estimation.
This function assumes a symmetric matrix.
Parameters
----------
dense_funm
An implementation of a function of a dense matrix.
For example, the output of
[funm.dense_funm_sym_eigh][matfree.funm.dense_funm_sym_eigh]
[funm.dense_funm_schur][matfree.funm.dense_funm_schur]
tridiag_sym
An implementation of tridiagonalisation.
E.g., the output of
[decomp.tridiag_sym][matfree.decomp.tridiag_sym].
"""
# Todo: expect these to be passed by the user.
dense_funm = dense_funm_sym_eigh(matfun)
algorithm = decomp.tridiag_sym(order, materialize=True)

def quadform(matvec, v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
Expand All @@ -205,7 +223,7 @@ def matvec_flat(v_flat, *p):
flat, unflatten = tree_util.ravel_pytree(Av)
return flat

_, dense, *_ = algorithm(matvec_flat, v0_flat, *parameters)
_, dense, *_ = tridiag_sym(matvec_flat, v0_flat, *parameters)

fA = dense_funm(dense)
e1 = np.eye(len(fA))[0, :]
Expand All @@ -214,6 +232,7 @@ def matvec_flat(v_flat, *p):
return quadform


# todo: expect bidiag() to be passed here
def integrand_funm_product_logdet(depth, /):
r"""Construct the integrand for the log-determinant of a matrix-product.
Expand All @@ -222,6 +241,7 @@ def integrand_funm_product_logdet(depth, /):
return integrand_funm_product(np.log, depth)


# todo: expect bidiag() to be passed here
def integrand_funm_product_schatten_norm(power, depth, /):
r"""Construct the integrand for the $p$-th power of the Schatten-p norm."""

Expand All @@ -232,6 +252,8 @@ def matfun(x):
return integrand_funm_product(matfun, depth)


# todo: expect bidiag() to be passed here
# todo: expect dense_funm_svd() to be passed here
def integrand_funm_product(matfun, depth, /):
r"""Construct the integrand for matrix-function-trace estimation.
Expand Down
15 changes: 8 additions & 7 deletions tests/test_funm/test_integrand_funm_sym_logdet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Tests for Lanczos functionality."""

from matfree import funm, stochtrace, test_util
from matfree import decomp, funm, stochtrace, test_util
from matfree.backend import linalg, np, prng, testing


@testing.fixture()
def A(n, num_significant_eigvals):
def make_A(n, num_significant_eigvals):
"""Make a positive definite matrix with certain spectrum."""
# 'Invent' a spectrum. Use the number of pre-defined eigenvalues.
d = np.arange(n) / n + 1.0
Expand All @@ -19,17 +18,18 @@ def A(n, num_significant_eigvals):
@testing.parametrize("order", [10])
# usually: ~1.5 * num_significant_eigvals.
# But logdet seems to converge sooo much faster.
def test_logdet_spd(A, order):
def test_logdet_spd(n, num_significant_eigvals, order):
"""Assert that the log-determinant estimation matches the true log-determinant."""
n, _ = np.shape(A)
A = make_A(n, num_significant_eigvals)

def matvec(x):
return {"fx": A @ x["fx"]}

key = prng.prng_key(1)
args_like = {"fx": np.ones((n,), dtype=float)}
sampler = stochtrace.sampler_normal(args_like, num=10)
integrand = funm.integrand_funm_sym_logdet(order)
tridiag_sym = decomp.tridiag_sym(order, materialize=True)
integrand = funm.integrand_funm_sym_logdet(tridiag_sym)
estimate = stochtrace.estimator(integrand, sampler)
received = estimate(matvec, key)

Expand All @@ -49,7 +49,8 @@ def test_logdet_spd_exact_for_full_order_lanczos(n):

# Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm
order = n - 1
integrand = funm.integrand_funm_sym_logdet(order)
tridiag_sym = decomp.tridiag_sym(order, materialize=True)
integrand = funm.integrand_funm_sym_logdet(tridiag_sym)

# Construct a vector without that does not have expected 2-norm equal to "dim"
x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10
Expand Down
5 changes: 3 additions & 2 deletions tutorials/1_log_determinants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax
import jax.numpy as jnp

from matfree import funm, stochtrace
from matfree import decomp, funm, stochtrace

# Set up a matrix.

Expand All @@ -27,7 +27,8 @@ def matvec(x):
# Estimate log-determinants with stochastic Lanczos quadrature.

order = 3
problem = funm.integrand_funm_sym_logdet(order)
tridiag_sym = decomp.tridiag_sym(order)
problem = funm.integrand_funm_sym_logdet(tridiag_sym)
sampler = stochtrace.sampler_normal(x_like, num=1_000)
estimator = stochtrace.estimator(problem, sampler=sampler)
logdet = estimator(matvec, jax.random.PRNGKey(1))
Expand Down
5 changes: 3 additions & 2 deletions tutorials/2_pytree_logdeterminants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax
import jax.numpy as jnp

from matfree import funm, stochtrace
from matfree import decomp, funm, stochtrace

# Create a test-problem: a function that maps a pytree (dict) to a pytree (tuple).
# Its (regularised) Gauss--Newton Hessian shall be the matrix-vector product
Expand Down Expand Up @@ -53,7 +53,8 @@ def fun(fx, /):

matvec = make_matvec(alpha=0.1)
order = 3
integrand = funm.integrand_funm_sym_logdet(order)
tridiag_sym = decomp.tridiag_sym(order)
integrand = funm.integrand_funm_sym_logdet(tridiag_sym)
sample_fun = stochtrace.sampler_normal(f0, num=10)
estimator = stochtrace.estimator(integrand, sampler=sample_fun)
key = jax.random.PRNGKey(1)
Expand Down

0 comments on commit d7d2f36

Please sign in to comment.