Skip to content

Commit

Permalink
Move matvec to estimation instead of algorithm-construction to avoid …
Browse files Browse the repository at this point in the history
…future code duplication (#203)

* Move matvec-arguments to estimate-functions

* Update the matrix function code to the new API

* Update the matrix-function code

* Move a todo

* Update the remaining matfree functions to the fun(matvec, v, *p) signature

* Update documentation

* Update the tutorial code

* Format the tutorials
  • Loading branch information
pnkraemer authored May 31, 2024
1 parent 1b164b1 commit fd99af5
Show file tree
Hide file tree
Showing 27 changed files with 161 additions and 160 deletions.
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ test:
pytest
python -m doctest README.md
python -m doctest matfree/*.py

python tutorials/1_log_determinants.py
python tutorials/2_pytree_logdeterminants.py
python tutorials/3_uncertainty_quantification.py
python tutorials/4_control_variates.py
python tutorials/5_vector_calculus.py
python tutorials/6_low_memory_trace_estimation.py

clean-preview:
git clean -xdn
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ Estimate the trace of the matrix:
>>>
>>> # Set Hutchinson's method up to compute the traces
>>> # (instead of, e.g., diagonals)
>>> integrand = stochtrace.integrand_trace(matvec)
>>> integrand = stochtrace.integrand_trace()
>>>
>>> # Compute an estimator
>>> estimate = stochtrace.estimator(integrand, sampler)

>>> # Estimate
>>> key = jax.random.PRNGKey(1)
>>> trace = jax.jit(estimate)(key)
>>> trace = estimate(matvec, key)
>>>
>>> print(trace)
508.9
Expand Down
55 changes: 28 additions & 27 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
from matfree.backend.typing import Array, Callable


def tridiag_sym(
matvec, krylov_depth, /, *, reortho: str = "full", custom_vjp: bool = True
):
def tridiag_sym(krylov_depth, /, *, reortho: str = "full", custom_vjp: bool = True):
r"""Construct an implementation of **tridiagonalisation**.
Uses pre-allocation, and full reorthogonalisation if `reortho` is set to `"full"`.
Expand Down Expand Up @@ -46,21 +44,21 @@ def tridiag_sym(
"""

if reortho == "full":
return _tridiag_reortho_full(matvec, krylov_depth, custom_vjp=custom_vjp)
return _tridiag_reortho_full(krylov_depth, custom_vjp=custom_vjp)
if reortho == "none":
return _tridiag_reortho_none(matvec, krylov_depth, custom_vjp=custom_vjp)
return _tridiag_reortho_none(krylov_depth, custom_vjp=custom_vjp)

msg = f"reortho={reortho} unsupported. Choose eiter {'full', 'none'}."
raise ValueError(msg)


def _tridiag_reortho_full(matvec, krylov_depth, /, *, custom_vjp):
def _tridiag_reortho_full(krylov_depth, /, *, custom_vjp):
# Implement via Arnoldi to use the reorthogonalised adjoints.
# The complexity difference is minimal with full reortho.
alg = hessenberg(matvec, krylov_depth, custom_vjp=custom_vjp, reortho="full")
alg = hessenberg(krylov_depth, custom_vjp=custom_vjp, reortho="full")

def estimate(vec, *params):
Q, H, v, _norm = alg(vec, *params)
def estimate(matvec, vec, *params):
Q, H, v, _norm = alg(matvec, vec, *params)

T = 0.5 * (H + H.T)
diags = linalg.diagonal(T, offset=0)
Expand All @@ -72,16 +70,16 @@ def estimate(vec, *params):
return estimate


def _tridiag_reortho_none(matvec, krylov_depth, /, *, custom_vjp):
def estimate(vec, *params):
def _tridiag_reortho_none(krylov_depth, /, *, custom_vjp):
def estimate(matvec, vec, *params):
*values, _ = _tridiag_forward(matvec, krylov_depth, vec, *params)
return values

def estimate_fwd(vec, *params):
value = estimate(vec, *params)
def estimate_fwd(matvec, vec, *params):
value = estimate(matvec, vec, *params)
return value, (value, (linalg.vector_norm(vec), *params))

def estimate_bwd(cache, vjp_incoming):
def estimate_bwd(matvec, cache, vjp_incoming):
# Read incoming gradients and stack related quantities
(dxs, (dalphas, dbetas)), (dx_last, dbeta_last) = vjp_incoming
dxs = np.concatenate((dxs, dx_last[None]))
Expand All @@ -107,7 +105,7 @@ def estimate_bwd(cache, vjp_incoming):
return grads

if custom_vjp:
estimate = func.custom_vjp(estimate)
estimate = func.custom_vjp(estimate, nondiff_argnums=[0])
estimate.defvjp(estimate_fwd, estimate_bwd) # type: ignore

return estimate
Expand Down Expand Up @@ -240,7 +238,6 @@ def _tridiag_adjoint_step(


def hessenberg(
matvec,
krylov_depth,
/,
*,
Expand Down Expand Up @@ -279,18 +276,18 @@ def hessenberg(
msg = f"Unexpected input for {reortho}: either of {reortho_expected} expected."
raise TypeError(msg)

def estimate_public(v, *params):
def estimate(matvec, v, *params):
matvec_convert, aux_args = func.closure_convert(matvec, v, *params)
return estimate_backend(matvec_convert, v, *params, *aux_args)
return _estimate(matvec_convert, v, *params, *aux_args)

def estimate_backend(matvec_convert: Callable, v, *params):
def _estimate(matvec_convert: Callable, v, *params):
reortho_ = reortho_vjp if reortho_vjp != "match" else reortho_vjp
return _hessenberg_forward(
matvec_convert, krylov_depth, v, *params, reortho=reortho_
)

def estimate_fwd(matvec_convert: Callable, v, *params):
outputs = estimate_backend(matvec_convert, v, *params)
outputs = _estimate(matvec_convert, v, *params)
return outputs, (outputs, params)

def estimate_bwd(matvec_convert: Callable, cache, vjp_incoming):
Expand All @@ -312,9 +309,9 @@ def estimate_bwd(matvec_convert: Callable, cache, vjp_incoming):
)

if custom_vjp:
estimate_backend = func.custom_vjp(estimate_backend, nondiff_argnums=(0,))
estimate_backend.defvjp(estimate_fwd, estimate_bwd) # type: ignore
return estimate_public
_estimate = func.custom_vjp(_estimate, nondiff_argnums=(0,))
_estimate.defvjp(estimate_fwd, estimate_bwd) # type: ignore
return estimate


def _hessenberg_forward(matvec, krylov_depth, v, *params, reortho: str):
Expand Down Expand Up @@ -486,7 +483,7 @@ def _extract_diag(x, offset=0):
return linalg.diagonal_matrix(diag, offset=offset)


def bidiag(Av: Callable, vA: Callable, depth, /, matrix_shape):
def bidiag(depth: int, /, matrix_shape):
"""Construct an implementation of **bidiagonalisation**.
Uses pre-allocation and full reorthogonalisation.
Expand All @@ -512,12 +509,16 @@ def bidiag(Av: Callable, vA: Callable, depth, /, matrix_shape):
msg3 = f"for a matrix with shape {matrix_shape}."
raise ValueError(msg1 + msg2 + msg3)

def estimate(v0, *parameters):
# todo: move the matvecs to the estimate() functions
# of tridiag and hessenberg. Then, update the SLQ functions.
# Finally, give all methods here a materialise=True argument
# and simplify the SLQ code massively.
def estimate(Av: Callable, vA: Callable, v0, *parameters):
v0_norm, length = _normalise(v0)
init_val = init(v0_norm)

def body_fun(_, s):
return step(s, *parameters)
return step(Av, vA, s, *parameters)

result = control_flow.fori_loop(
0, depth + 1, body_fun=body_fun, init_val=init_val
Expand All @@ -541,7 +542,7 @@ def init(init_vec: Array) -> State:
v0, _ = _normalise(init_vec)
return State(0, Us, Vs, alphas, betas, np.zeros(()), v0)

def step(state: State, *parameters) -> State:
def step(Av, vA, state: State, *parameters) -> State:
i, Us, Vs, alphas, betas, beta, vk = state
Vs = Vs.at[i].set(vk)
betas = betas.at[i].set(beta)
Expand Down
5 changes: 3 additions & 2 deletions matfree/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from matfree.backend.typing import Array, Callable, Tuple


# todo: why does this function not return a callable?
def svd_partial(
v0: Array, depth: int, Av: Callable, vA: Callable, matrix_shape: Tuple[int, ...]
):
Expand All @@ -29,8 +30,8 @@ def svd_partial(
Shape of the matrix involved in matrix-vector and vector-matrix products.
"""
# Factorise the matrix
algorithm = decomp.bidiag(Av, vA, depth, matrix_shape=matrix_shape)
u, (d, e), vt, *_ = algorithm(v0)
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape)
u, (d, e), vt, *_ = algorithm(Av, vA, v0)

# Compute SVD of factorisation
B = _bidiagonal_dense(d, e)
Expand Down
51 changes: 25 additions & 26 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
This includes matrix-function-vector products
$$
(f, v, p) \\mapsto f(A(p))v
(f, A, v, p) \\mapsto f(A(p))v
$$
as well as matrix-function extensions for stochastic trace estimation,
which provide
$$
(f, v, p) \\mapsto v^\\top f(A(p))v.
(f, A, v, p) \\mapsto v^\\top f(A(p))v.
$$
Plug these integrands into
Expand All @@ -31,9 +31,9 @@
>>>
>>> # Compute a matrix-logarithm with Lanczos' algorithm
>>> matfun = dense_funm_sym_eigh(jnp.log)
>>> tridiag = decomp.tridiag_sym(lambda s: A @ s, 4)
>>> tridiag = decomp.tridiag_sym(4)
>>> matfun_vec = funm_lanczos_sym(matfun, tridiag)
>>> matfun_vec(v)
>>> matfun_vec(lambda s: A @ s, v)
Array([-4.1, -1.3, -2.2, -2.1, -1.2, -3.3, -0.2, 0.3, 0.7, 0.9], dtype=float32)
"""

Expand Down Expand Up @@ -102,7 +102,7 @@ def _funm_polyexpand(matrix_poly_alg, /):
"""Compute a matrix-function-vector product via a polynomial expansion."""
(lower, upper), init_func, step_func, extract_func = matrix_poly_alg

def matvec(vec, *parameters):
def matrix_function_vector_product(vec, *parameters):
final_state = control_flow.fori_loop(
lower=lower,
upper=upper,
Expand All @@ -111,7 +111,7 @@ def matvec(vec, *parameters):
)
return extract_func(final_state)

return matvec
return matrix_function_vector_product


def funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable, /) -> Callable:
Expand All @@ -133,10 +133,10 @@ def funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable, /) -> Callable
[decomp.tridiag_sym][matfree.decomp.tridiag_sym].
"""

def estimate(vec, *parameters):
def estimate(matvec: Callable, vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
(basis, (diag, off_diag)), _ = tridiag_sym(vec, *parameters)
(basis, (diag, off_diag)), _ = tridiag_sym(matvec, vec, *parameters)
matrix = _todense_tridiag_sym(diag, off_diag)

funm = dense_funm(matrix)
Expand All @@ -146,24 +146,24 @@ def estimate(vec, *parameters):
return estimate


def integrand_funm_sym_logdet(order, matvec, /):
def integrand_funm_sym_logdet(order, /):
"""Construct the integrand for the log-determinant.
This function assumes a symmetric, positive definite matrix.
"""
return integrand_funm_sym(np.log, order, matvec)
return integrand_funm_sym(np.log, order)


def integrand_funm_sym(matfun, order, matvec, /):
def integrand_funm_sym(matfun, order, /):
"""Construct the integrand for matrix-function-trace estimation.
This function assumes a symmetric matrix.
"""
# todo: if we ask the user to flatten their matvecs,
# then we can give this code the same API as funm_lanczos_sym.
# todo: expect these (dense_funm and algorithm) to be passed by the user.
dense_funm = dense_funm_sym_eigh(matfun)
algorithm = decomp.tridiag_sym(order)

def quadform(v0, *parameters):
def quadform(matvec, v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
length = linalg.vector_norm(v0_flat)
v0_flat /= length
Expand All @@ -174,8 +174,7 @@ def matvec_flat(v_flat, *p):
flat, unflatten = tree_util.ravel_pytree(Av)
return flat

algorithm = decomp.tridiag_sym(matvec_flat, order)
(_, (diag, off_diag)), _ = algorithm(v0_flat, *parameters)
(_, (diag, off_diag)), _ = algorithm(matvec_flat, v0_flat, *parameters)

dense = _todense_tridiag_sym(diag, off_diag)
fA = dense_funm(dense)
Expand All @@ -185,33 +184,34 @@ def matvec_flat(v_flat, *p):
return quadform


def integrand_funm_product_logdet(depth, matvec, vecmat, /):
def integrand_funm_product_logdet(depth, /):
r"""Construct the integrand for the log-determinant of a matrix-product.
Here, "product" refers to $X = A^\top A$.
"""
return integrand_funm_product(np.log, depth, matvec, vecmat)
return integrand_funm_product(np.log, depth)


def integrand_funm_product_schatten_norm(power, depth, matvec, vecmat, /):
def integrand_funm_product_schatten_norm(power, depth, /):
r"""Construct the integrand for the $p$-th power of the Schatten-p norm."""

def matfun(x):
"""Matrix-function for Schatten-p norms."""
return x ** (power / 2)

return integrand_funm_product(matfun, depth, matvec, vecmat)
return integrand_funm_product(matfun, depth)


def integrand_funm_product(matfun, depth, matvec, vecmat, /):
def integrand_funm_product(matfun, depth, /):
r"""Construct the integrand for matrix-function-trace estimation.
Instead of the trace of a function of a matrix,
compute the trace of a function of the product of matrices.
Here, "product" refers to $X = A^\top A$.
"""

def quadform(v0, *parameters):
def quadform(matvecs, v0, *parameters):
matvec, vecmat = matvecs
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
length = linalg.vector_norm(v0_flat)
v0_flat /= length
Expand All @@ -231,10 +231,9 @@ def vecmat_flat(w_flat):
return tree_util.ravel_pytree(wA)[0]

# Decompose into orthogonal-bidiag-orthogonal
algorithm = decomp.bidiag(
lambda v: matvec_flat(v)[0], vecmat_flat, depth, matrix_shape=matrix_shape
)
output = algorithm(v0_flat, *parameters)
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape)
matvec_flat_p = lambda v: matvec_flat(v)[0] # noqa: E731
output = algorithm(matvec_flat_p, vecmat_flat, v0_flat, *parameters)
u, (d, e), vt, *_ = output

# Compute SVD of factorisation
Expand Down
Loading

0 comments on commit fd99af5

Please sign in to comment.