From b5fb4ae90718eea5021518bbe6e60833132731b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Fri, 31 May 2024 09:13:37 +0200 Subject: [PATCH] Implement matrix functions by combining matrix decompositions with dense matrix functions (#201) * Implement funm_via_lanczos by assuming a readily assembled Lanczos decomposition * Implement Lanczos-funm's via combining decompositions with various small funm's * Update some docstrings * Rename integrand_ to integrand_funm_ to prepare moving the code to funm.py * Move the content of stochtrace_funm.py to funm.py because all internal logic is the same --- matfree/backend/linalg.py | 4 + matfree/decomp.py | 5 +- matfree/funm.py | 203 ++++++++++++++++-- matfree/stochtrace_funm.py | 132 ------------ tests/test_funm/test_funm_lanczos_sym.py | 12 +- .../test_integrand_funm_logdet_product.py} | 6 +- .../test_integrand_funm_logdet_sym.py} | 6 +- .../test_integrand_funm_schatten_norm.py} | 4 +- tutorials/1_log_determinants.py | 10 +- tutorials/2_pytree_logdeterminants.py | 4 +- 10 files changed, 214 insertions(+), 172 deletions(-) delete mode 100644 matfree/stochtrace_funm.py rename tests/{test_stochtrace_funm/test_integrand_logdet_product.py => test_funm/test_integrand_funm_logdet_product.py} (92%) rename tests/{test_stochtrace_funm/test_integrand_logdet_sym.py => test_funm/test_integrand_funm_logdet_sym.py} (91%) rename tests/{test_stochtrace_funm/test_integrand_schatten_norm.py => test_funm/test_integrand_funm_schatten_norm.py} (91%) diff --git a/matfree/backend/linalg.py b/matfree/backend/linalg.py index 89310e0..8ca87ec 100644 --- a/matfree/backend/linalg.py +++ b/matfree/backend/linalg.py @@ -77,3 +77,7 @@ def solve(A, b, /): def cg(Av, b, /): return jax.scipy.sparse.linalg.cg(Av, b) + + +def funm_schur(A, f, /): + return jax.scipy.linalg.funm(A, f) diff --git a/matfree/decomp.py b/matfree/decomp.py index 98a19d6..fdbed56 100644 --- a/matfree/decomp.py +++ b/matfree/decomp.py @@ -3,9 +3,8 @@ This module includes various Lanczos-decompositions of matrices (tri-diagonal, bi-diagonal, etc.). -For stochastic Lanczos quadrature, see -[matfree.stochtrace_funm][matfree.stochtrace_funm]. -For matrix-function-vector products, see +For stochastic Lanczos quadrature and +matrix-function-vector products, see [matfree.funm][matfree.funm]. """ diff --git a/matfree/funm.py b/matfree/funm.py index 0d37ab2..d606db3 100644 --- a/matfree/funm.py +++ b/matfree/funm.py @@ -1,9 +1,27 @@ """Matrix-free implementations of functions of matrices. +This includes matrix-function-vector products + +$$ +(f, v, p) \\mapsto f(A(p))v +$$ + +as well as matrix-function extensions for stochastic trace estimation, +which provide + +$$ +(f, v, p) \\mapsto v^\\top f(A(p))v. +$$ + +Plug these integrands into +[matfree.stochtrace.estimator][matfree.stochtrace.estimator]. + + Examples -------- >>> import jax.random >>> import jax.numpy as jnp +>>> from matfree import decomp >>> >>> jnp.set_printoptions(1) >>> @@ -12,13 +30,15 @@ >>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,)) >>> >>> # Compute a matrix-logarithm with Lanczos' algorithm ->>> matfun_vec = funm_lanczos_sym(jnp.log, 4, lambda s: A @ s) +>>> matfun = dense_funm_sym_eigh(jnp.log) +>>> tridiag = decomp.tridiag_sym(lambda s: A @ s, 4) +>>> matfun_vec = funm_lanczos_sym(matfun, tridiag) >>> matfun_vec(v) Array([-4.1, -1.3, -2.2, -2.1, -1.2, -3.3, -0.2, 0.3, 0.7, 0.9], dtype=float32) """ from matfree import decomp -from matfree.backend import containers, control_flow, func, linalg, np +from matfree.backend import containers, control_flow, func, linalg, np, tree_util from matfree.backend.typing import Array, Callable @@ -94,36 +114,185 @@ def matvec(vec, *parameters): return matvec -# todo: if we pass decomp.tridiag_sym instead of order & matvec, -# the user gets more control over questions like reorthogonalisation -def funm_lanczos_sym(matfun: Callable, order: int, matvec: Callable, /) -> Callable: +def funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable, /) -> Callable: """Implement a matrix-function-vector product via Lanczos' tridiagonalisation. This algorithm uses Lanczos' tridiagonalisation and therefore applies only to symmetric matrices. + + Parameters + ---------- + dense_funm + An implementation of a function of a dense matrix. + For example, the output of + [funm.dense_funm_sym_eigh][matfree.funm.dense_funm_sym_eigh] + [funm.dense_funm_schur][matfree.funm.dense_funm_schur] + tridiag_sym + An implementation of tridiagonalisation. + E.g., the output of + [decomp.tridiag_sym][matfree.decomp.tridiag_sym]. """ - algorithm = decomp.tridiag_sym(matvec, order) def estimate(vec, *parameters): length = linalg.vector_norm(vec) vec /= length - (basis, (diag, off_diag)), _ = algorithm(vec, *parameters) - eigvals, eigvecs = _eigh_tridiag_sym(diag, off_diag) + (basis, (diag, off_diag)), _ = tridiag_sym(vec, *parameters) + matrix = _todense_tridiag_sym(diag, off_diag) - fx_eigvals = func.vmap(matfun)(eigvals) - return length * (basis.T @ (eigvecs @ (fx_eigvals * eigvecs[0, :]))) + funm = dense_funm(matrix) + e1 = np.eye(len(matrix))[0, :] + return length * (basis.T @ funm @ e1) return estimate -def _eigh_tridiag_sym(diag, off_diag): - # todo: once jax supports eigh_tridiagonal(eigvals_only=False), - # use it here. Until then: an eigen-decomposition of size (order + 1) - # does not hurt too much... +def integrand_funm_sym_logdet(order, matvec, /): + """Construct the integrand for the log-determinant. + + This function assumes a symmetric, positive definite matrix. + """ + return integrand_funm_sym(np.log, order, matvec) + + +def integrand_funm_sym(matfun, order, matvec, /): + """Construct the integrand for matrix-function-trace estimation. + + This function assumes a symmetric matrix. + """ + # todo: if we ask the user to flatten their matvecs, + # then we can give this code the same API as funm_lanczos_sym. + dense_funm = dense_funm_sym_eigh(matfun) + + def quadform(v0, *parameters): + v0_flat, v_unflatten = tree_util.ravel_pytree(v0) + length = linalg.vector_norm(v0_flat) + v0_flat /= length + + def matvec_flat(v_flat, *p): + v = v_unflatten(v_flat) + Av = matvec(v, *p) + flat, unflatten = tree_util.ravel_pytree(Av) + return flat + + algorithm = decomp.tridiag_sym(matvec_flat, order) + (_, (diag, off_diag)), _ = algorithm(v0_flat, *parameters) + + dense = _todense_tridiag_sym(diag, off_diag) + fA = dense_funm(dense) + e1 = np.eye(len(fA))[0, :] + return length**2 * linalg.inner(e1, fA @ e1) + + return quadform + + +def integrand_funm_product_logdet(depth, matvec, vecmat, /): + r"""Construct the integrand for the log-determinant of a matrix-product. + + Here, "product" refers to $X = A^\top A$. + """ + return integrand_funm_product(np.log, depth, matvec, vecmat) + + +def integrand_funm_product_schatten_norm(power, depth, matvec, vecmat, /): + r"""Construct the integrand for the $p$-th power of the Schatten-p norm.""" + + def matfun(x): + """Matrix-function for Schatten-p norms.""" + return x ** (power / 2) + + return integrand_funm_product(matfun, depth, matvec, vecmat) + + +def integrand_funm_product(matfun, depth, matvec, vecmat, /): + r"""Construct the integrand for matrix-function-trace estimation. + + Instead of the trace of a function of a matrix, + compute the trace of a function of the product of matrices. + Here, "product" refers to $X = A^\top A$. + """ + + def quadform(v0, *parameters): + v0_flat, v_unflatten = tree_util.ravel_pytree(v0) + length = linalg.vector_norm(v0_flat) + v0_flat /= length + + def matvec_flat(v_flat, *p): + v = v_unflatten(v_flat) + Av = matvec(v, *p) + flat, unflatten = tree_util.ravel_pytree(Av) + return flat, tree_util.partial_pytree(unflatten) + + w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat) + matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat)) + + def vecmat_flat(w_flat): + w = w_unflatten(w_flat) + wA = vecmat(w, *parameters) + return tree_util.ravel_pytree(wA)[0] + + # Decompose into orthogonal-bidiag-orthogonal + algorithm = decomp.bidiag( + lambda v: matvec_flat(v)[0], vecmat_flat, depth, matrix_shape=matrix_shape + ) + output = algorithm(v0_flat, *parameters) + u, (d, e), vt, _ = output + + # Compute SVD of factorisation + B = _todense_bidiag(d, e) + + # todo: turn the following lines into dense_funm_svd() + _, S, Vt = linalg.svd(B, full_matrices=False) + + # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], + # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] + eigvals, eigvecs = S**2, Vt.T + fx_eigvals = func.vmap(matfun)(eigvals) + return length**2 * linalg.inner(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) + + return quadform + + +def dense_funm_sym_eigh(matfun): + """Implement dense matrix-functions via symmetric eigendecompositions. + + Use it to construct one of the matrix-free matrix-function implementations, + e.g. [matfree.funm.funm_lanczos_sym][matfree.funm.funm_lanczos_sym]. + """ + + def fun(dense_matrix): + eigvals, eigvecs = linalg.eigh(dense_matrix) + fx_eigvals = func.vmap(matfun)(eigvals) + return eigvecs @ linalg.diagonal(fx_eigvals) @ eigvecs.T + + return fun + + +def dense_funm_schur(matfun): + """Implement dense matrix-functions via symmetric Schur decompositions. + + Use it to construct one of the matrix-free matrix-function implementations, + e.g. [matfree.funm.funm_lanczos_sym][matfree.funm.funm_lanczos_sym]. + """ + + def fun(dense_matrix): + return linalg.funm_schur(dense_matrix, matfun) + + return fun + + +# todo: if we move this logic to the decomposition algorithms +# (e.g. with a materalize=True flag in the decomposition construction), +# then all trace_of_funm implementation reduce to very few lines. + + +def _todense_tridiag_sym(diag, off_diag): diag = linalg.diagonal_matrix(diag) offdiag1 = linalg.diagonal_matrix(off_diag, -1) offdiag2 = linalg.diagonal_matrix(off_diag, 1) - dense_matrix = diag + offdiag1 + offdiag2 + return diag + offdiag1 + offdiag2 + - eigvals, eigvecs = linalg.eigh(dense_matrix) - return eigvals, eigvecs +def _todense_bidiag(d, e): + diag = linalg.diagonal_matrix(d) + offdiag = linalg.diagonal_matrix(e, 1) + return diag + offdiag diff --git a/matfree/stochtrace_funm.py b/matfree/stochtrace_funm.py deleted file mode 100644 index 1aa9606..0000000 --- a/matfree/stochtrace_funm.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Stochastic estimation of traces of **functions of matrices**. - -This module extends [matfree.stochtrace][matfree.stochtrace]. - -""" - -from matfree import decomp -from matfree.backend import func, linalg, np, tree_util - -# todo: currently, all dense matrix-functions are computed -# via eigh(). But for e.g. log and exp, we might want to do -# something else. - - -def integrand_sym_logdet(order, matvec, /): - """Construct the integrand for the log-determinant. - - This function assumes a symmetric, positive definite matrix. - """ - return integrand_sym(np.log, order, matvec) - - -def integrand_sym(matfun, order, matvec, /): - """Construct the integrand for matrix-function-trace estimation. - - This function assumes a symmetric matrix. - """ - - def quadform(v0, *parameters): - v0_flat, v_unflatten = tree_util.ravel_pytree(v0) - length = linalg.vector_norm(v0_flat) - v0_flat /= length - - def matvec_flat(v_flat, *p): - v = v_unflatten(v_flat) - Av = matvec(v, *p) - flat, unflatten = tree_util.ravel_pytree(Av) - return flat - - algorithm = decomp.tridiag_sym(matvec_flat, order) - (_, (diag, off_diag)), _ = algorithm(v0_flat, *parameters) - eigvals, eigvecs = _eigh_tridiag_sym(diag, off_diag) - - # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], - # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] - fx_eigvals = func.vmap(matfun)(eigvals) - return length**2 * linalg.inner(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) - - return quadform - - -def integrand_product_logdet(depth, matvec, vecmat, /): - r"""Construct the integrand for the log-determinant of a matrix-product. - - Here, "product" refers to $X = A^\top A$. - """ - return integrand_product(np.log, depth, matvec, vecmat) - - -def integrand_product_schatten_norm(power, depth, matvec, vecmat, /): - r"""Construct the integrand for the p-th power of the Schatten-p norm.""" - - def matfun(x): - """Matrix-function for Schatten-p norms.""" - return x ** (power / 2) - - return integrand_product(matfun, depth, matvec, vecmat) - - -def integrand_product(matfun, depth, matvec, vecmat, /): - """Construct the integrand for matrix-function-trace estimation. - - Instead of the trace of a function of a matrix, - compute the trace of a function of the product of matrices. - Here, "product" refers to $X = A^\top A$. - """ - - def quadform(v0, *parameters): - v0_flat, v_unflatten = tree_util.ravel_pytree(v0) - length = linalg.vector_norm(v0_flat) - v0_flat /= length - - def matvec_flat(v_flat, *p): - v = v_unflatten(v_flat) - Av = matvec(v, *p) - flat, unflatten = tree_util.ravel_pytree(Av) - return flat, tree_util.partial_pytree(unflatten) - - w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat) - matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat)) - - def vecmat_flat(w_flat): - w = w_unflatten(w_flat) - wA = vecmat(w, *parameters) - return tree_util.ravel_pytree(wA)[0] - - # Decompose into orthogonal-bidiag-orthogonal - algorithm = decomp.bidiag( - lambda v: matvec_flat(v)[0], vecmat_flat, depth, matrix_shape=matrix_shape - ) - output = algorithm(v0_flat, *parameters) - u, (d, e), vt, _ = output - - # Compute SVD of factorisation - B = _bidiagonal_dense(d, e) - _, S, Vt = linalg.svd(B, full_matrices=False) - - # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], - # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] - eigvals, eigvecs = S**2, Vt.T - fx_eigvals = func.vmap(matfun)(eigvals) - return length**2 * linalg.inner(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) - - return quadform - - -def _bidiagonal_dense(d, e): - diag = linalg.diagonal_matrix(d) - offdiag = linalg.diagonal_matrix(e, 1) - return diag + offdiag - - -def _eigh_tridiag_sym(diag, off_diag): - # todo: once jax supports eigh_tridiagonal(eigvals_only=False), - # use it here. Until then: an eigen-decomposition of size (order + 1) - # does not hurt too much... - diag = linalg.diagonal_matrix(diag) - offdiag1 = linalg.diagonal_matrix(off_diag, -1) - offdiag2 = linalg.diagonal_matrix(off_diag, 1) - dense_matrix = diag + offdiag1 + offdiag2 - eigvals, eigvecs = linalg.eigh(dense_matrix) - return eigvals, eigvecs diff --git a/tests/test_funm/test_funm_lanczos_sym.py b/tests/test_funm/test_funm_lanczos_sym.py index 39db2db..5287f78 100644 --- a/tests/test_funm/test_funm_lanczos_sym.py +++ b/tests/test_funm/test_funm_lanczos_sym.py @@ -1,10 +1,11 @@ """Test matrix-function-vector products via Lanczos' algorithm.""" -from matfree import funm, test_util -from matfree.backend import linalg, np, prng +from matfree import decomp, funm, test_util +from matfree.backend import linalg, np, prng, testing -def test_funm_lanczos_sym(n=11): +@testing.parametrize("dense_funm", [funm.dense_funm_sym_eigh, funm.dense_funm_schur]) +def test_funm_lanczos_sym_matches_eigh_implementation(dense_funm, n=11): """Test matrix-function-vector products via Lanczos' algorithm.""" # Create a test-problem: matvec, matrix function, # vector, and parameters (a matrix). @@ -26,7 +27,8 @@ def fun(x): expected = log_matrix @ v # Compute the matrix-function vector product - order = 6 - matfun_vec = funm.funm_lanczos_sym(fun, order, matvec) + dense_funm = dense_funm(fun) + lanczos = decomp.tridiag_sym(matvec, 6) + matfun_vec = funm.funm_lanczos_sym(dense_funm, lanczos) received = matfun_vec(v, matrix) assert np.allclose(expected, received, atol=1e-6) diff --git a/tests/test_stochtrace_funm/test_integrand_logdet_product.py b/tests/test_funm/test_integrand_funm_logdet_product.py similarity index 92% rename from tests/test_stochtrace_funm/test_integrand_logdet_product.py rename to tests/test_funm/test_integrand_funm_logdet_product.py index 19f0aec..abd0b0e 100644 --- a/tests/test_stochtrace_funm/test_integrand_logdet_product.py +++ b/tests/test_funm/test_integrand_funm_logdet_product.py @@ -1,6 +1,6 @@ """Test stochastic Lanczos quadrature for log-determinants of matrix-products.""" -from matfree import stochtrace, stochtrace_funm, test_util +from matfree import funm, stochtrace, test_util from matfree.backend import linalg, np, prng, testing @@ -31,7 +31,7 @@ def vecmat(x): x_like = {"fx": np.ones((ncols,), dtype=float)} fun = stochtrace.sampler_normal(x_like, num=400) - problem = stochtrace_funm.integrand_product_logdet(order, matvec, vecmat) + problem = funm.integrand_funm_product_logdet(order, matvec, vecmat) estimate = stochtrace.estimator(problem, fun) received = estimate(key) @@ -53,7 +53,7 @@ def test_logdet_product_exact_for_full_order_lanczos(n): # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm order = n - 1 - integrand = stochtrace_funm.integrand_product_logdet( + integrand = funm.integrand_funm_product_logdet( order, lambda v: A @ v, lambda v: v @ A ) diff --git a/tests/test_stochtrace_funm/test_integrand_logdet_sym.py b/tests/test_funm/test_integrand_funm_logdet_sym.py similarity index 91% rename from tests/test_stochtrace_funm/test_integrand_logdet_sym.py rename to tests/test_funm/test_integrand_funm_logdet_sym.py index a87c48b..c314d6f 100644 --- a/tests/test_stochtrace_funm/test_integrand_logdet_sym.py +++ b/tests/test_funm/test_integrand_funm_logdet_sym.py @@ -1,6 +1,6 @@ """Tests for Lanczos functionality.""" -from matfree import stochtrace, stochtrace_funm, test_util +from matfree import funm, stochtrace, test_util from matfree.backend import linalg, np, prng, testing @@ -29,7 +29,7 @@ def matvec(x): key = prng.prng_key(1) args_like = {"fx": np.ones((n,), dtype=float)} sampler = stochtrace.sampler_normal(args_like, num=10) - integrand = stochtrace_funm.integrand_sym_logdet(order, matvec) + integrand = funm.integrand_funm_sym_logdet(order, matvec) estimate = stochtrace.estimator(integrand, sampler) received = estimate(key) @@ -49,7 +49,7 @@ def test_logdet_spd_exact_for_full_order_lanczos(n): # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm order = n - 1 - integrand = stochtrace_funm.integrand_sym_logdet(order, lambda v: A @ v) + integrand = funm.integrand_funm_sym_logdet(order, lambda v: A @ v) # Construct a vector without that does not have expected 2-norm equal to "dim" x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10 diff --git a/tests/test_stochtrace_funm/test_integrand_schatten_norm.py b/tests/test_funm/test_integrand_funm_schatten_norm.py similarity index 91% rename from tests/test_stochtrace_funm/test_integrand_schatten_norm.py rename to tests/test_funm/test_integrand_funm_schatten_norm.py index f3481af..55644be 100644 --- a/tests/test_stochtrace_funm/test_integrand_schatten_norm.py +++ b/tests/test_funm/test_integrand_funm_schatten_norm.py @@ -1,6 +1,6 @@ """Test stochastic Lanczos quadrature for Schatten-p-norms.""" -from matfree import stochtrace, stochtrace_funm, test_util +from matfree import funm, stochtrace, test_util from matfree.backend import linalg, np, prng, testing @@ -27,7 +27,7 @@ def test_schatten_norm(A, order, power): _, ncols = np.shape(A) args_like = np.ones((ncols,), dtype=float) sampler = stochtrace.sampler_normal(args_like, num=500) - integrand = stochtrace_funm.integrand_product_schatten_norm( + integrand = funm.integrand_funm_product_schatten_norm( power, order, lambda v: A @ v, lambda v: A.T @ v ) estimate = stochtrace.estimator(integrand, sampler) diff --git a/tutorials/1_log_determinants.py b/tutorials/1_log_determinants.py index b480cb3..b268fd6 100644 --- a/tutorials/1_log_determinants.py +++ b/tutorials/1_log_determinants.py @@ -7,7 +7,7 @@ import jax import jax.numpy as jnp -from matfree import stochtrace, stochtrace_funm +from matfree import funm, stochtrace # Set up a matrix. @@ -27,7 +27,7 @@ def matvec(x): # Estimate log-determinants with stochastic Lanczos quadrature. order = 3 -problem = stochtrace_funm.integrand_sym_logdet(order, matvec) +problem = funm.integrand_funm_sym_logdet(order, matvec) sampler = stochtrace.sampler_normal(x_like, num=1_000) estimator = stochtrace.estimator(problem, sampler=sampler) logdet = estimator(jax.random.PRNGKey(1)) @@ -47,18 +47,18 @@ def matvec(x): A /= nrows**2 -def matvec_right(x): +def matvec_r(x): """Compute a matrix-vector product.""" return A @ x -def vecmat_left(x): +def vecmat_l(x): """Compute a vector-matrix product.""" return x @ A order = 3 -problem = stochtrace_funm.integrand_product_logdet(order, matvec_right, vecmat_left) +problem = funm.integrand_funm_product_logdet(order, matvec_r, vecmat_l) sampler = stochtrace.sampler_normal(x_like, num=1_000) estimator = stochtrace.estimator(problem, sampler=sampler) logdet = estimator(jax.random.PRNGKey(1)) diff --git a/tutorials/2_pytree_logdeterminants.py b/tutorials/2_pytree_logdeterminants.py index 0459723..5e238b5 100644 --- a/tutorials/2_pytree_logdeterminants.py +++ b/tutorials/2_pytree_logdeterminants.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from matfree import stochtrace, stochtrace_funm +from matfree import funm, stochtrace # Create a test-problem: a function that maps a pytree (dict) to a pytree (tuple). # Its (regularised) Gauss--Newton Hessian shall be the matrix-vector product @@ -53,7 +53,7 @@ def fun(fx, /): matvec = make_matvec(alpha=0.1) order = 3 -integrand = stochtrace_funm.integrand_sym_logdet(order, matvec) +integrand = funm.integrand_funm_sym_logdet(order, matvec) sample_fun = stochtrace.sampler_normal(f0, num=10) estimator = stochtrace.estimator(integrand, sampler=sample_fun) key = jax.random.PRNGKey(1)