From fdccbb4081b1c599695691c1c46a38f90601b3fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Mon, 27 May 2024 11:01:06 +0200 Subject: [PATCH] Create dedicated `funm` and `funm_trace` modules for functions of matrices (#188) * Create matfree/funm.py and matfree/funm_trace.py to collect matrix-function functionality * Update the documentation --- matfree/{polynomial.py => funm.py} | 63 +++++- matfree/funm_trace.py | 132 ++++++++++++ matfree/hutchinson.py | 2 +- matfree/lanczos.py | 193 +++--------------- .../test_funm_chebyshev.py} | 6 +- .../test_funm_lanczos_spd.py} | 8 +- .../test_integrand_logdet_product.py | 6 +- .../test_integrand_logdet_spd.py | 6 +- .../test_integrand_schatten_norm.py | 4 +- tutorials/1_log_determinants.py | 6 +- tutorials/2_pytree_logdeterminants.py | 4 +- 11 files changed, 235 insertions(+), 195 deletions(-) rename matfree/{polynomial.py => funm.py} (55%) create mode 100644 matfree/funm_trace.py rename tests/{test_polynomial/test_funm_vector_product_chebyshev.py => test_funm/test_funm_chebyshev.py} (84%) rename tests/{test_lanczos/test_funm_vector_product.py => test_funm/test_funm_lanczos_spd.py} (78%) rename tests/{test_lanczos => test_funm_trace}/test_integrand_logdet_product.py (93%) rename tests/{test_lanczos => test_funm_trace}/test_integrand_logdet_spd.py (92%) rename tests/{test_lanczos => test_funm_trace}/test_integrand_schatten_norm.py (92%) diff --git a/matfree/polynomial.py b/matfree/funm.py similarity index 55% rename from matfree/polynomial.py rename to matfree/funm.py index 081b7dd..d2652f9 100644 --- a/matfree/polynomial.py +++ b/matfree/funm.py @@ -1,14 +1,29 @@ -"""Approximate matrix-function-vector products with polynomial expansions. - -This module does not include Lanczos-style approximations, which are -in [matfree.lanczos][matfree.lanczos]. +"""Functions of matrices implemented as matrix-function-vector products. + + +Examples +-------- +>>> import jax.random +>>> import jax.numpy as jnp +>>> +>>> jnp.set_printoptions(1) +>>> +>>> M = jax.random.normal(jax.random.PRNGKey(1), shape=(10, 10)) +>>> A = M.T @ M +>>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,)) +>>> +>>> # Compute a matrix-logarithm with Lanczos' algorithm +>>> matfun_vec = funm_lanczos_spd(jnp.log, 4, lambda s: A @ s) +>>> matfun_vec(v) +Array([-4. , -2.1, -2.7, -1.9, -1.3, -3.5, -0.5, -0.1, 0.3, 1.5], dtype=float32) """ -from matfree.backend import containers, control_flow, np +from matfree import lanczos +from matfree.backend import containers, control_flow, func, linalg, np from matfree.backend.typing import Array -def funm_vector_product_chebyshev(matfun, order, matvec, /): +def funm_chebyshev(matfun, order, matvec, /): """Compute a matrix-function-vector product via Chebyshev's algorithm. This function assumes that the **spectrum of the matrix-vector product @@ -56,7 +71,7 @@ def extract_func(val: _ChebyshevState): return val.interpolation alg = (0, order - 1), init_func, recursion_func, extract_func - return _funm_vector_product_polyexpand(alg) + return _funm_polyexpand(alg) def _chebyshev_nodes(n, /): @@ -64,7 +79,7 @@ def _chebyshev_nodes(n, /): return np.cos((2 * k - 1) / (2 * n) * np.pi()) -def _funm_vector_product_polyexpand(matrix_poly_alg, /): +def _funm_polyexpand(matrix_poly_alg, /): """Implement a matrix-function-vector product via a polynomial expansion.""" (lower, upper), init_func, step_func, extract_func = matrix_poly_alg @@ -78,3 +93,35 @@ def matvec(vec, *parameters): return extract_func(final_state) return matvec + + +def funm_lanczos_spd(matfun, order, matvec, /): + """Implement a matrix-function-vector product via Lanczos' algorithm. + + This algorithm uses Lanczos' tridiagonalisation with full re-orthogonalisation + and therefore applies only to symmetric, positive definite matrices. + """ + algorithm = lanczos.alg_tridiag_full_reortho(matvec, order) + + def estimate(vec, *parameters): + length = linalg.vector_norm(vec) + vec /= length + basis, (diag, off_diag) = algorithm(vec, *parameters) + eigvals, eigvecs = _eigh_tridiag(diag, off_diag) + + fx_eigvals = func.vmap(matfun)(eigvals) + return length * (basis.T @ (eigvecs @ (fx_eigvals * eigvecs[0, :]))) + + return estimate + + +def _eigh_tridiag(diag, off_diag): + # todo: once jax supports eigh_tridiagonal(eigvals_only=False), + # use it here. Until then: an eigen-decomposition of size (order + 1) + # does not hurt too much... + diag = linalg.diagonal_matrix(diag) + offdiag1 = linalg.diagonal_matrix(off_diag, -1) + offdiag2 = linalg.diagonal_matrix(off_diag, 1) + dense_matrix = diag + offdiag1 + offdiag2 + eigvals, eigvecs = linalg.eigh(dense_matrix) + return eigvals, eigvecs diff --git a/matfree/funm_trace.py b/matfree/funm_trace.py new file mode 100644 index 0000000..6709799 --- /dev/null +++ b/matfree/funm_trace.py @@ -0,0 +1,132 @@ +"""Stochastic estimation of traces of functions of matrices. + +This module extends [matfree.hutchinson][matfree.hutchinson]. + +""" + +from matfree import lanczos +from matfree.backend import func, linalg, np, tree_util + +# todo: currently, all dense matrix-functions are computed +# via eigh(). But for e.g. log and exp, we might want to do +# something else. + + +def integrand_spd_logdet(order, matvec, /): + """Construct the integrand for the log-determinant. + + This function assumes a symmetric, positive definite matrix. + """ + return integrand_spd(np.log, order, matvec) + + +def integrand_spd(matfun, order, matvec, /): + """Quadratic form for stochastic Lanczos quadrature. + + This function assumes a symmetric, positive definite matrix. + """ + + def quadform(v0, *parameters): + v0_flat, v_unflatten = tree_util.ravel_pytree(v0) + length = linalg.vector_norm(v0_flat) + v0_flat /= length + + def matvec_flat(v_flat, *p): + v = v_unflatten(v_flat) + Av = matvec(v, *p) + flat, unflatten = tree_util.ravel_pytree(Av) + return flat + + algorithm = lanczos.alg_tridiag_full_reortho(matvec_flat, order) + _, (diag, off_diag) = algorithm(v0_flat, *parameters) + eigvals, eigvecs = _eigh_tridiag(diag, off_diag) + + # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], + # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] + fx_eigvals = func.vmap(matfun)(eigvals) + return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) + + return quadform + + +def integrand_product_logdet(depth, matvec, vecmat, /): + r"""Construct the integrand for the log-determinant of a matrix-product. + + Here, "product" refers to $X = A^\top A$. + """ + return integrand_product(np.log, depth, matvec, vecmat) + + +def integrand_product_schatten_norm(power, depth, matvec, vecmat, /): + r"""Construct the integrand for the p-th power of the Schatten-p norm.""" + + def matfun(x): + """Matrix-function for Schatten-p norms.""" + return x ** (power / 2) + + return integrand_product(matfun, depth, matvec, vecmat) + + +def integrand_product(matfun, depth, matvec, vecmat, /): + r"""Construct the integrand for the trace of a function of a matrix-product. + + Instead of the trace of a function of a matrix, + compute the trace of a function of the product of matrices. + Here, "product" refers to $X = A^\top A$. + """ + + def quadform(v0, *parameters): + v0_flat, v_unflatten = tree_util.ravel_pytree(v0) + length = linalg.vector_norm(v0_flat) + v0_flat /= length + + def matvec_flat(v_flat, *p): + v = v_unflatten(v_flat) + Av = matvec(v, *p) + flat, unflatten = tree_util.ravel_pytree(Av) + return flat, tree_util.partial_pytree(unflatten) + + w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat) + matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat)) + + def vecmat_flat(w_flat): + w = w_unflatten(w_flat) + wA = vecmat(w, *parameters) + return tree_util.ravel_pytree(wA)[0] + + # Decompose into orthogonal-bidiag-orthogonal + algorithm = lanczos.alg_bidiag_full_reortho( + lambda v: matvec_flat(v)[0], vecmat_flat, depth, matrix_shape=matrix_shape + ) + output = algorithm(v0_flat, *parameters) + u, (d, e), vt, _ = output + + # Compute SVD of factorisation + B = _bidiagonal_dense(d, e) + _, S, Vt = linalg.svd(B, full_matrices=False) + + # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], + # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] + eigvals, eigvecs = S**2, Vt.T + fx_eigvals = func.vmap(matfun)(eigvals) + return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) + + return quadform + + +def _bidiagonal_dense(d, e): + diag = linalg.diagonal_matrix(d) + offdiag = linalg.diagonal_matrix(e, 1) + return diag + offdiag + + +def _eigh_tridiag(diag, off_diag): + # todo: once jax supports eigh_tridiagonal(eigvals_only=False), + # use it here. Until then: an eigen-decomposition of size (order + 1) + # does not hurt too much... + diag = linalg.diagonal_matrix(diag) + offdiag1 = linalg.diagonal_matrix(off_diag, -1) + offdiag2 = linalg.diagonal_matrix(off_diag, 1) + dense_matrix = diag + offdiag1 + offdiag2 + eigvals, eigvecs = linalg.eigh(dense_matrix) + return eigvals, eigvecs diff --git a/matfree/hutchinson.py b/matfree/hutchinson.py index 175d3fe..97b3935 100644 --- a/matfree/hutchinson.py +++ b/matfree/hutchinson.py @@ -1,4 +1,4 @@ -"""Hutchinson-style estimation.""" +"""Stochastic estimation of traces, diagonals, and more.""" from matfree.backend import func, linalg, np, prng, tree_util diff --git a/matfree/lanczos.py b/matfree/lanczos.py index 4f0b3aa..9bcc89b 100644 --- a/matfree/lanczos.py +++ b/matfree/lanczos.py @@ -1,171 +1,16 @@ -"""All things Lanczos' algorithm. - -This includes -stochastic Lanczos quadrature (extending the integrands -in [hutchinson][matfree.hutchinson] to those that implement -stochastic Lanczos quadrature), -Lanczos-implementations of matrix-function-vector products, -and various Lanczos-decompositions of matrices. - -Examples --------- ->>> import jax.random ->>> import jax.numpy as jnp ->>> ->>> jnp.set_printoptions(1) ->>> ->>> M = jax.random.normal(jax.random.PRNGKey(1), shape=(10, 10)) ->>> A = M.T @ M ->>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,)) ->>> ->>> # Compute a matrix-logarithm with Lanczos' algorithm ->>> matfun_vec = funm_vector_product_spd(jnp.log, 4, lambda s: A @ s) ->>> matfun_vec(v) -Array([-4. , -2.1, -2.7, -1.9, -1.3, -3.5, -0.5, -0.1, 0.3, 1.5], dtype=float32) -""" - -from matfree.backend import containers, control_flow, func, linalg, np, tree_util -from matfree.backend.typing import Array, Callable, Tuple - - -def integrand_spd_logdet(order, matvec, /): - """Construct the integrand for the log-determinant. - - This function assumes a symmetric, positive definite matrix. - """ - return integrand_spd(np.log, order, matvec) - - -def integrand_spd(matfun, order, matvec, /): - """Quadratic form for stochastic Lanczos quadrature. - - This function assumes a symmetric, positive definite matrix. - """ - - def quadform(v0, *parameters): - v0_flat, v_unflatten = tree_util.ravel_pytree(v0) - length = linalg.vector_norm(v0_flat) - v0_flat /= length - - def matvec_flat(v_flat, *p): - v = v_unflatten(v_flat) - Av = matvec(v, *p) - flat, unflatten = tree_util.ravel_pytree(Av) - return flat - - algorithm = alg_tridiag_full_reortho(matvec_flat, order) - _, (diag, off_diag) = algorithm(v0_flat, *parameters) - eigvals, eigvecs = _eigh_tridiag(diag, off_diag) - - # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], - # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] - fx_eigvals = func.vmap(matfun)(eigvals) - return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) - - return quadform - - -def integrand_product_logdet(depth, matvec, vecmat, /): - r"""Construct the integrand for the log-determinant of a matrix-product. - - Here, "product" refers to $X = A^\top A$. - """ - return integrand_product(np.log, depth, matvec, vecmat) - - -def integrand_product_schatten_norm(power, depth, matvec, vecmat, /): - r"""Construct the integrand for the p-th power of the Schatten-p norm.""" - - def matfun(x): - """Matrix-function for Schatten-p norms.""" - return x ** (power / 2) - - return integrand_product(matfun, depth, matvec, vecmat) - - -def integrand_product(matfun, depth, matvec, vecmat, /): - r"""Construct the integrand for the trace of a function of a matrix-product. - - Instead of the trace of a function of a matrix, - compute the trace of a function of the product of matrices. - Here, "product" refers to $X = A^\top A$. - """ - - def quadform(v0, *parameters): - v0_flat, v_unflatten = tree_util.ravel_pytree(v0) - length = linalg.vector_norm(v0_flat) - v0_flat /= length - - def matvec_flat(v_flat, *p): - v = v_unflatten(v_flat) - Av = matvec(v, *p) - flat, unflatten = tree_util.ravel_pytree(Av) - return flat, tree_util.partial_pytree(unflatten) - - w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat) - matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat)) - - def vecmat_flat(w_flat): - w = w_unflatten(w_flat) - wA = vecmat(w, *parameters) - return tree_util.ravel_pytree(wA)[0] - - # Decompose into orthogonal-bidiag-orthogonal - algorithm = alg_bidiag_full_reortho( - lambda v: matvec_flat(v)[0], vecmat_flat, depth, matrix_shape=matrix_shape - ) - output = algorithm(v0_flat, *parameters) - u, (d, e), vt, _ = output - - # Compute SVD of factorisation - B = _bidiagonal_dense(d, e) - _, S, Vt = linalg.svd(B, full_matrices=False) - - # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], - # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] - eigvals, eigvecs = S**2, Vt.T - fx_eigvals = func.vmap(matfun)(eigvals) - return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) - - return quadform - - -def _bidiagonal_dense(d, e): - diag = linalg.diagonal_matrix(d) - offdiag = linalg.diagonal_matrix(e, 1) - return diag + offdiag - - -def funm_vector_product_spd(matfun, order, matvec, /): - """Implement a matrix-function-vector product via Lanczos' algorithm. +"""Lanczos-style matrix decompositions. - This algorithm uses Lanczos' tridiagonalisation with full re-orthogonalisation - and therefore applies only to symmetric, positive definite matrices. - """ - algorithm = alg_tridiag_full_reortho(matvec, order) - - def estimate(vec, *parameters): - length = linalg.vector_norm(vec) - vec /= length - basis, (diag, off_diag) = algorithm(vec, *parameters) - eigvals, eigvecs = _eigh_tridiag(diag, off_diag) - - fx_eigvals = func.vmap(matfun)(eigvals) - return length * (basis.T @ (eigvecs @ (fx_eigvals * eigvecs[0, :]))) - - return estimate +This module includes various Lanczos-decompositions of matrices +(tridiagonalisation, bidiagonalisation, etc.). +For stochastic Lanczos quadrature, see +[matfree.funm_trace][matfree.funm_trace]. +For matrix-function-vector products, see +[matfree.funm][matfree.funm]. +""" -def _eigh_tridiag(diag, off_diag): - # todo: once jax supports eigh_tridiagonal(eigvals_only=False), - # use it here. Until then: an eigen-decomposition of size (order + 1) - # does not hurt too much... - diag = linalg.diagonal_matrix(diag) - offdiag1 = linalg.diagonal_matrix(off_diag, -1) - offdiag2 = linalg.diagonal_matrix(off_diag, 1) - dense_matrix = diag + offdiag1 + offdiag2 - eigvals, eigvecs = linalg.eigh(dense_matrix) - return eigvals, eigvecs +from matfree.backend import containers, control_flow, func, linalg, np +from matfree.backend.typing import Array, Callable, Tuple def svd_approx( @@ -410,3 +255,21 @@ def body_fun(_, s): result = control_flow.fori_loop(lower, upper, body_fun=body_fun, init_val=init_val) return extract(result) + + +def _bidiagonal_dense(d, e): + diag = linalg.diagonal_matrix(d) + offdiag = linalg.diagonal_matrix(e, 1) + return diag + offdiag + + +def _eigh_tridiag(diag, off_diag): + # todo: once jax supports eigh_tridiagonal(eigvals_only=False), + # use it here. Until then: an eigen-decomposition of size (order + 1) + # does not hurt too much... + diag = linalg.diagonal_matrix(diag) + offdiag1 = linalg.diagonal_matrix(off_diag, -1) + offdiag2 = linalg.diagonal_matrix(off_diag, 1) + dense_matrix = diag + offdiag1 + offdiag2 + eigvals, eigvecs = linalg.eigh(dense_matrix) + return eigvals, eigvecs diff --git a/tests/test_polynomial/test_funm_vector_product_chebyshev.py b/tests/test_funm/test_funm_chebyshev.py similarity index 84% rename from tests/test_polynomial/test_funm_vector_product_chebyshev.py rename to tests/test_funm/test_funm_chebyshev.py index f0b6946..d232ec5 100644 --- a/tests/test_polynomial/test_funm_vector_product_chebyshev.py +++ b/tests/test_funm/test_funm_chebyshev.py @@ -1,10 +1,10 @@ """Test matrix-polynomial-vector algorithms via Chebyshev's recursion.""" -from matfree import polynomial, test_util +from matfree import funm, test_util from matfree.backend import linalg, np, prng -def test_funm_vector_product_chebyshev(n=12): +def test_funm_chebyshev(n=12): """Test matrix-polynomial-vector algorithms via Chebyshev's recursion.""" # Create a test-problem: matvec, matrix function, # vector, and parameters (a matrix). @@ -27,7 +27,7 @@ def fun(x): # Create an implementation of the Chebyshev-algorithm order = 6 - matfun_vec = polynomial.funm_vector_product_chebyshev(fun, order, matvec) + matfun_vec = funm.funm_chebyshev(fun, order, matvec) # Compute the matrix-function vector product received = matfun_vec(v, matrix) diff --git a/tests/test_lanczos/test_funm_vector_product.py b/tests/test_funm/test_funm_lanczos_spd.py similarity index 78% rename from tests/test_lanczos/test_funm_vector_product.py rename to tests/test_funm/test_funm_lanczos_spd.py index 69d5fa1..0389ea8 100644 --- a/tests/test_lanczos/test_funm_vector_product.py +++ b/tests/test_funm/test_funm_lanczos_spd.py @@ -1,10 +1,10 @@ """Test matrix-function-vector products via Lanczos' algorithm.""" -from matfree import lanczos, test_util +from matfree import funm, test_util from matfree.backend import linalg, np, prng -def test_funm_vector_product(n=11): +def test_funm_lanczos_spd(n=11): """Test matrix-function-vector products via Lanczos' algorithm.""" # Create a test-problem: matvec, matrix function, # vector, and parameters (a matrix). @@ -12,8 +12,6 @@ def test_funm_vector_product(n=11): def matvec(x, p): return p @ x - # todo: write a test for matfun=np.inv, - # because this application seems to be brittle def fun(x): return np.sin(x) @@ -29,6 +27,6 @@ def fun(x): # Compute the matrix-function vector product order = 6 - matfun_vec = lanczos.funm_vector_product_spd(fun, order, matvec) + matfun_vec = funm.funm_lanczos_spd(fun, order, matvec) received = matfun_vec(v, matrix) assert np.allclose(expected, received, atol=1e-6) diff --git a/tests/test_lanczos/test_integrand_logdet_product.py b/tests/test_funm_trace/test_integrand_logdet_product.py similarity index 93% rename from tests/test_lanczos/test_integrand_logdet_product.py rename to tests/test_funm_trace/test_integrand_logdet_product.py index 402a2a4..83a7431 100644 --- a/tests/test_lanczos/test_integrand_logdet_product.py +++ b/tests/test_funm_trace/test_integrand_logdet_product.py @@ -1,6 +1,6 @@ """Test stochastic Lanczos quadrature for log-determinants of matrix-products.""" -from matfree import hutchinson, lanczos, test_util +from matfree import funm_trace, hutchinson, test_util from matfree.backend import linalg, np, prng, testing @@ -31,7 +31,7 @@ def vecmat(x): x_like = {"fx": np.ones((ncols,), dtype=float)} fun = hutchinson.sampler_normal(x_like, num=400) - problem = lanczos.integrand_product_logdet(order, matvec, vecmat) + problem = funm_trace.integrand_product_logdet(order, matvec, vecmat) estimate = hutchinson.hutchinson(problem, fun) received = estimate(key) @@ -53,7 +53,7 @@ def test_logdet_product_exact_for_full_order_lanczos(n): # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm order = n - 1 - integrand = lanczos.integrand_product_logdet( + integrand = funm_trace.integrand_product_logdet( order, lambda v: A @ v, lambda v: v @ A ) diff --git a/tests/test_lanczos/test_integrand_logdet_spd.py b/tests/test_funm_trace/test_integrand_logdet_spd.py similarity index 92% rename from tests/test_lanczos/test_integrand_logdet_spd.py rename to tests/test_funm_trace/test_integrand_logdet_spd.py index a5629a0..77102a3 100644 --- a/tests/test_lanczos/test_integrand_logdet_spd.py +++ b/tests/test_funm_trace/test_integrand_logdet_spd.py @@ -1,6 +1,6 @@ """Tests for Lanczos functionality.""" -from matfree import hutchinson, lanczos, test_util +from matfree import funm_trace, hutchinson, test_util from matfree.backend import linalg, np, prng, testing @@ -29,7 +29,7 @@ def matvec(x): key = prng.prng_key(1) args_like = {"fx": np.ones((n,), dtype=float)} sampler = hutchinson.sampler_normal(args_like, num=10) - integrand = lanczos.integrand_spd_logdet(order, matvec) + integrand = funm_trace.integrand_spd_logdet(order, matvec) estimate = hutchinson.hutchinson(integrand, sampler) received = estimate(key) @@ -49,7 +49,7 @@ def test_logdet_spd_exact_for_full_order_lanczos(n): # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm order = n - 1 - integrand = lanczos.integrand_spd_logdet(order, lambda v: A @ v) + integrand = funm_trace.integrand_spd_logdet(order, lambda v: A @ v) # Construct a vector without that does not have expected 2-norm equal to "dim" x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10 diff --git a/tests/test_lanczos/test_integrand_schatten_norm.py b/tests/test_funm_trace/test_integrand_schatten_norm.py similarity index 92% rename from tests/test_lanczos/test_integrand_schatten_norm.py rename to tests/test_funm_trace/test_integrand_schatten_norm.py index 6516890..af45858 100644 --- a/tests/test_lanczos/test_integrand_schatten_norm.py +++ b/tests/test_funm_trace/test_integrand_schatten_norm.py @@ -1,6 +1,6 @@ """Test stochastic Lanczos quadrature for Schatten-p-norms.""" -from matfree import hutchinson, lanczos, test_util +from matfree import funm_trace, hutchinson, test_util from matfree.backend import linalg, np, prng, testing @@ -27,7 +27,7 @@ def test_schatten_norm(A, order, power): _, ncols = np.shape(A) args_like = np.ones((ncols,), dtype=float) sampler = hutchinson.sampler_normal(args_like, num=500) - integrand = lanczos.integrand_product_schatten_norm( + integrand = funm_trace.integrand_product_schatten_norm( power, order, lambda v: A @ v, lambda v: A.T @ v ) estimate = hutchinson.hutchinson(integrand, sampler) diff --git a/tutorials/1_log_determinants.py b/tutorials/1_log_determinants.py index c73e8ef..d7bf573 100644 --- a/tutorials/1_log_determinants.py +++ b/tutorials/1_log_determinants.py @@ -7,7 +7,7 @@ import jax import jax.numpy as jnp -from matfree import hutchinson, lanczos +from matfree import funm_trace, hutchinson # Set up a matrix. @@ -27,7 +27,7 @@ def matvec(x): # Estimate log-determinants with stochastic Lanczos quadrature. order = 3 -problem = lanczos.integrand_spd_logdet(order, matvec) +problem = funm_trace.integrand_spd_logdet(order, matvec) sampler = hutchinson.sampler_normal(x_like, num=1_000) estimator = hutchinson.hutchinson(problem, sample_fun=sampler) logdet = estimator(jax.random.PRNGKey(1)) @@ -58,7 +58,7 @@ def vecmat_left(x): order = 3 -problem = lanczos.integrand_product_logdet(order, matvec_right, vecmat_left) +problem = funm_trace.integrand_product_logdet(order, matvec_right, vecmat_left) sampler = hutchinson.sampler_normal(x_like, num=1_000) estimator = hutchinson.hutchinson(problem, sample_fun=sampler) logdet = estimator(jax.random.PRNGKey(1)) diff --git a/tutorials/2_pytree_logdeterminants.py b/tutorials/2_pytree_logdeterminants.py index 30f7ccc..bed7b54 100644 --- a/tutorials/2_pytree_logdeterminants.py +++ b/tutorials/2_pytree_logdeterminants.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from matfree import hutchinson, lanczos +from matfree import funm_trace, hutchinson # Create a test-problem: a function that maps a pytree (dict) to a pytree (tuple). # Its (regularised) Gauss--Newton Hessian shall be the matrix-vector product @@ -53,7 +53,7 @@ def fun(fx, /): matvec = make_matvec(alpha=0.1) order = 3 -integrand = lanczos.integrand_spd_logdet(order, matvec) +integrand = funm_trace.integrand_spd_logdet(order, matvec) sample_fun = hutchinson.sampler_normal(f0, num=10) estimator = hutchinson.hutchinson(integrand, sample_fun=sample_fun) key = jax.random.PRNGKey(1)