From d7d2f36444ac9e8598ff358b1cf0a48f8b8e0add Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Thu, 29 Aug 2024 12:24:32 +0200 Subject: [PATCH] Make integrand_funm_tridiag_sym expect algorithms as inputs so the user can choose (e.g.) a reorthogonalisation mode (#210) * Make integrand_funm_tridiag_sym expect a decomposition as an input to allow switching between (e.g.) reorthogonalisation modes * Mark the next todos in the source --- matfree/funm.py | 38 +++++++++++++++---- .../test_integrand_funm_sym_logdet.py | 15 ++++---- tutorials/1_log_determinants.py | 5 ++- tutorials/2_pytree_logdeterminants.py | 5 ++- 4 files changed, 44 insertions(+), 19 deletions(-) diff --git a/matfree/funm.py b/matfree/funm.py index e432bc2..8c644ef 100644 --- a/matfree/funm.py +++ b/matfree/funm.py @@ -137,7 +137,6 @@ def estimate(matvec: Callable, vec, *parameters): length = linalg.vector_norm(vec) vec /= length Q, matrix, *_ = tridiag_sym(matvec, vec, *parameters) - # matrix = _todense_tridiag_sym(diag, off_diag) funm = dense_funm(matrix) e1 = np.eye(len(matrix))[0, :] @@ -177,22 +176,41 @@ def estimate(matvec: Callable, vec, *parameters): return estimate -def integrand_funm_sym_logdet(order, /): +def integrand_funm_sym_logdet(tridiag_sym: Callable, /): """Construct the integrand for the log-determinant. This function assumes a symmetric, positive definite matrix. + + Parameters + ---------- + tridiag_sym + An implementation of tridiagonalisation. + E.g., the output of + [decomp.tridiag_sym][matfree.decomp.tridiag_sym]. + """ - return integrand_funm_sym(np.log, order) + dense_funm = dense_funm_sym_eigh(np.log) + return integrand_funm_sym(dense_funm, tridiag_sym) -def integrand_funm_sym(matfun, order, /): +def integrand_funm_sym(dense_funm, tridiag_sym, /): """Construct the integrand for matrix-function-trace estimation. This function assumes a symmetric matrix. + + Parameters + ---------- + dense_funm + An implementation of a function of a dense matrix. + For example, the output of + [funm.dense_funm_sym_eigh][matfree.funm.dense_funm_sym_eigh] + [funm.dense_funm_schur][matfree.funm.dense_funm_schur] + tridiag_sym + An implementation of tridiagonalisation. + E.g., the output of + [decomp.tridiag_sym][matfree.decomp.tridiag_sym]. + """ - # Todo: expect these to be passed by the user. - dense_funm = dense_funm_sym_eigh(matfun) - algorithm = decomp.tridiag_sym(order, materialize=True) def quadform(matvec, v0, *parameters): v0_flat, v_unflatten = tree_util.ravel_pytree(v0) @@ -205,7 +223,7 @@ def matvec_flat(v_flat, *p): flat, unflatten = tree_util.ravel_pytree(Av) return flat - _, dense, *_ = algorithm(matvec_flat, v0_flat, *parameters) + _, dense, *_ = tridiag_sym(matvec_flat, v0_flat, *parameters) fA = dense_funm(dense) e1 = np.eye(len(fA))[0, :] @@ -214,6 +232,7 @@ def matvec_flat(v_flat, *p): return quadform +# todo: expect bidiag() to be passed here def integrand_funm_product_logdet(depth, /): r"""Construct the integrand for the log-determinant of a matrix-product. @@ -222,6 +241,7 @@ def integrand_funm_product_logdet(depth, /): return integrand_funm_product(np.log, depth) +# todo: expect bidiag() to be passed here def integrand_funm_product_schatten_norm(power, depth, /): r"""Construct the integrand for the $p$-th power of the Schatten-p norm.""" @@ -232,6 +252,8 @@ def matfun(x): return integrand_funm_product(matfun, depth) +# todo: expect bidiag() to be passed here +# todo: expect dense_funm_svd() to be passed here def integrand_funm_product(matfun, depth, /): r"""Construct the integrand for matrix-function-trace estimation. diff --git a/tests/test_funm/test_integrand_funm_sym_logdet.py b/tests/test_funm/test_integrand_funm_sym_logdet.py index 4dfdbfa..dde4be7 100644 --- a/tests/test_funm/test_integrand_funm_sym_logdet.py +++ b/tests/test_funm/test_integrand_funm_sym_logdet.py @@ -1,11 +1,10 @@ """Tests for Lanczos functionality.""" -from matfree import funm, stochtrace, test_util +from matfree import decomp, funm, stochtrace, test_util from matfree.backend import linalg, np, prng, testing -@testing.fixture() -def A(n, num_significant_eigvals): +def make_A(n, num_significant_eigvals): """Make a positive definite matrix with certain spectrum.""" # 'Invent' a spectrum. Use the number of pre-defined eigenvalues. d = np.arange(n) / n + 1.0 @@ -19,9 +18,9 @@ def A(n, num_significant_eigvals): @testing.parametrize("order", [10]) # usually: ~1.5 * num_significant_eigvals. # But logdet seems to converge sooo much faster. -def test_logdet_spd(A, order): +def test_logdet_spd(n, num_significant_eigvals, order): """Assert that the log-determinant estimation matches the true log-determinant.""" - n, _ = np.shape(A) + A = make_A(n, num_significant_eigvals) def matvec(x): return {"fx": A @ x["fx"]} @@ -29,7 +28,8 @@ def matvec(x): key = prng.prng_key(1) args_like = {"fx": np.ones((n,), dtype=float)} sampler = stochtrace.sampler_normal(args_like, num=10) - integrand = funm.integrand_funm_sym_logdet(order) + tridiag_sym = decomp.tridiag_sym(order, materialize=True) + integrand = funm.integrand_funm_sym_logdet(tridiag_sym) estimate = stochtrace.estimator(integrand, sampler) received = estimate(matvec, key) @@ -49,7 +49,8 @@ def test_logdet_spd_exact_for_full_order_lanczos(n): # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm order = n - 1 - integrand = funm.integrand_funm_sym_logdet(order) + tridiag_sym = decomp.tridiag_sym(order, materialize=True) + integrand = funm.integrand_funm_sym_logdet(tridiag_sym) # Construct a vector without that does not have expected 2-norm equal to "dim" x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10 diff --git a/tutorials/1_log_determinants.py b/tutorials/1_log_determinants.py index 0aa6ed3..3b9454e 100644 --- a/tutorials/1_log_determinants.py +++ b/tutorials/1_log_determinants.py @@ -7,7 +7,7 @@ import jax import jax.numpy as jnp -from matfree import funm, stochtrace +from matfree import decomp, funm, stochtrace # Set up a matrix. @@ -27,7 +27,8 @@ def matvec(x): # Estimate log-determinants with stochastic Lanczos quadrature. order = 3 -problem = funm.integrand_funm_sym_logdet(order) +tridiag_sym = decomp.tridiag_sym(order) +problem = funm.integrand_funm_sym_logdet(tridiag_sym) sampler = stochtrace.sampler_normal(x_like, num=1_000) estimator = stochtrace.estimator(problem, sampler=sampler) logdet = estimator(matvec, jax.random.PRNGKey(1)) diff --git a/tutorials/2_pytree_logdeterminants.py b/tutorials/2_pytree_logdeterminants.py index bd00cd3..b90cfe9 100644 --- a/tutorials/2_pytree_logdeterminants.py +++ b/tutorials/2_pytree_logdeterminants.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from matfree import funm, stochtrace +from matfree import decomp, funm, stochtrace # Create a test-problem: a function that maps a pytree (dict) to a pytree (tuple). # Its (regularised) Gauss--Newton Hessian shall be the matrix-vector product @@ -53,7 +53,8 @@ def fun(fx, /): matvec = make_matvec(alpha=0.1) order = 3 -integrand = funm.integrand_funm_sym_logdet(order) +tridiag_sym = decomp.tridiag_sym(order) +integrand = funm.integrand_funm_sym_logdet(tridiag_sym) sample_fun = stochtrace.sampler_normal(f0, num=10) estimator = stochtrace.estimator(integrand, sampler=sample_fun) key = jax.random.PRNGKey(1)