Skip to content

Commit

Permalink
Create dedicated funm and funm_trace modules for functions of mat…
Browse files Browse the repository at this point in the history
…rices (#188)

* Create matfree/funm.py and matfree/funm_trace.py to collect matrix-function functionality

* Update the documentation
  • Loading branch information
pnkraemer authored May 27, 2024
1 parent 7622788 commit fdccbb4
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 195 deletions.
63 changes: 55 additions & 8 deletions matfree/polynomial.py → matfree/funm.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
"""Approximate matrix-function-vector products with polynomial expansions.
This module does not include Lanczos-style approximations, which are
in [matfree.lanczos][matfree.lanczos].
"""Functions of matrices implemented as matrix-function-vector products.
Examples
--------
>>> import jax.random
>>> import jax.numpy as jnp
>>>
>>> jnp.set_printoptions(1)
>>>
>>> M = jax.random.normal(jax.random.PRNGKey(1), shape=(10, 10))
>>> A = M.T @ M
>>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,))
>>>
>>> # Compute a matrix-logarithm with Lanczos' algorithm
>>> matfun_vec = funm_lanczos_spd(jnp.log, 4, lambda s: A @ s)
>>> matfun_vec(v)
Array([-4. , -2.1, -2.7, -1.9, -1.3, -3.5, -0.5, -0.1, 0.3, 1.5], dtype=float32)
"""

from matfree.backend import containers, control_flow, np
from matfree import lanczos
from matfree.backend import containers, control_flow, func, linalg, np
from matfree.backend.typing import Array


def funm_vector_product_chebyshev(matfun, order, matvec, /):
def funm_chebyshev(matfun, order, matvec, /):
"""Compute a matrix-function-vector product via Chebyshev's algorithm.
This function assumes that the **spectrum of the matrix-vector product
Expand Down Expand Up @@ -56,15 +71,15 @@ def extract_func(val: _ChebyshevState):
return val.interpolation

alg = (0, order - 1), init_func, recursion_func, extract_func
return _funm_vector_product_polyexpand(alg)
return _funm_polyexpand(alg)


def _chebyshev_nodes(n, /):
k = np.arange(n, step=1.0) + 1
return np.cos((2 * k - 1) / (2 * n) * np.pi())


def _funm_vector_product_polyexpand(matrix_poly_alg, /):
def _funm_polyexpand(matrix_poly_alg, /):
"""Implement a matrix-function-vector product via a polynomial expansion."""
(lower, upper), init_func, step_func, extract_func = matrix_poly_alg

Expand All @@ -78,3 +93,35 @@ def matvec(vec, *parameters):
return extract_func(final_state)

return matvec


def funm_lanczos_spd(matfun, order, matvec, /):
"""Implement a matrix-function-vector product via Lanczos' algorithm.
This algorithm uses Lanczos' tridiagonalisation with full re-orthogonalisation
and therefore applies only to symmetric, positive definite matrices.
"""
algorithm = lanczos.alg_tridiag_full_reortho(matvec, order)

def estimate(vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
basis, (diag, off_diag) = algorithm(vec, *parameters)
eigvals, eigvecs = _eigh_tridiag(diag, off_diag)

fx_eigvals = func.vmap(matfun)(eigvals)
return length * (basis.T @ (eigvecs @ (fx_eigvals * eigvecs[0, :])))

return estimate


def _eigh_tridiag(diag, off_diag):
# todo: once jax supports eigh_tridiagonal(eigvals_only=False),
# use it here. Until then: an eigen-decomposition of size (order + 1)
# does not hurt too much...
diag = linalg.diagonal_matrix(diag)
offdiag1 = linalg.diagonal_matrix(off_diag, -1)
offdiag2 = linalg.diagonal_matrix(off_diag, 1)
dense_matrix = diag + offdiag1 + offdiag2
eigvals, eigvecs = linalg.eigh(dense_matrix)
return eigvals, eigvecs
132 changes: 132 additions & 0 deletions matfree/funm_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Stochastic estimation of traces of functions of matrices.
This module extends [matfree.hutchinson][matfree.hutchinson].
"""

from matfree import lanczos
from matfree.backend import func, linalg, np, tree_util

# todo: currently, all dense matrix-functions are computed
# via eigh(). But for e.g. log and exp, we might want to do
# something else.


def integrand_spd_logdet(order, matvec, /):
"""Construct the integrand for the log-determinant.
This function assumes a symmetric, positive definite matrix.
"""
return integrand_spd(np.log, order, matvec)


def integrand_spd(matfun, order, matvec, /):
"""Quadratic form for stochastic Lanczos quadrature.
This function assumes a symmetric, positive definite matrix.
"""

def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat, *p):
v = v_unflatten(v_flat)
Av = matvec(v, *p)
flat, unflatten = tree_util.ravel_pytree(Av)
return flat

algorithm = lanczos.alg_tridiag_full_reortho(matvec_flat, order)
_, (diag, off_diag) = algorithm(v0_flat, *parameters)
eigvals, eigvecs = _eigh_tridiag(diag, off_diag)

# Since Q orthogonal (orthonormal) to v0, Q v = Q[0],
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
fx_eigvals = func.vmap(matfun)(eigvals)
return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])

return quadform


def integrand_product_logdet(depth, matvec, vecmat, /):
r"""Construct the integrand for the log-determinant of a matrix-product.
Here, "product" refers to $X = A^\top A$.
"""
return integrand_product(np.log, depth, matvec, vecmat)


def integrand_product_schatten_norm(power, depth, matvec, vecmat, /):
r"""Construct the integrand for the p-th power of the Schatten-p norm."""

def matfun(x):
"""Matrix-function for Schatten-p norms."""
return x ** (power / 2)

return integrand_product(matfun, depth, matvec, vecmat)


def integrand_product(matfun, depth, matvec, vecmat, /):
r"""Construct the integrand for the trace of a function of a matrix-product.
Instead of the trace of a function of a matrix,
compute the trace of a function of the product of matrices.
Here, "product" refers to $X = A^\top A$.
"""

def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat, *p):
v = v_unflatten(v_flat)
Av = matvec(v, *p)
flat, unflatten = tree_util.ravel_pytree(Av)
return flat, tree_util.partial_pytree(unflatten)

w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat)
matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat))

def vecmat_flat(w_flat):
w = w_unflatten(w_flat)
wA = vecmat(w, *parameters)
return tree_util.ravel_pytree(wA)[0]

# Decompose into orthogonal-bidiag-orthogonal
algorithm = lanczos.alg_bidiag_full_reortho(
lambda v: matvec_flat(v)[0], vecmat_flat, depth, matrix_shape=matrix_shape
)
output = algorithm(v0_flat, *parameters)
u, (d, e), vt, _ = output

# Compute SVD of factorisation
B = _bidiagonal_dense(d, e)
_, S, Vt = linalg.svd(B, full_matrices=False)

# Since Q orthogonal (orthonormal) to v0, Q v = Q[0],
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
eigvals, eigvecs = S**2, Vt.T
fx_eigvals = func.vmap(matfun)(eigvals)
return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])

return quadform


def _bidiagonal_dense(d, e):
diag = linalg.diagonal_matrix(d)
offdiag = linalg.diagonal_matrix(e, 1)
return diag + offdiag


def _eigh_tridiag(diag, off_diag):
# todo: once jax supports eigh_tridiagonal(eigvals_only=False),
# use it here. Until then: an eigen-decomposition of size (order + 1)
# does not hurt too much...
diag = linalg.diagonal_matrix(diag)
offdiag1 = linalg.diagonal_matrix(off_diag, -1)
offdiag2 = linalg.diagonal_matrix(off_diag, 1)
dense_matrix = diag + offdiag1 + offdiag2
eigvals, eigvecs = linalg.eigh(dense_matrix)
return eigvals, eigvecs
2 changes: 1 addition & 1 deletion matfree/hutchinson.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Hutchinson-style estimation."""
"""Stochastic estimation of traces, diagonals, and more."""

from matfree.backend import func, linalg, np, prng, tree_util

Expand Down
Loading

0 comments on commit fdccbb4

Please sign in to comment.