Skip to content

Commit

Permalink
Add a 'materialize' option to tridiag() and bidiag() to simplify down…
Browse files Browse the repository at this point in the history
…stream applications (#204)
  • Loading branch information
pnkraemer authored May 31, 2024
1 parent fd99af5 commit fddc9a7
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 65 deletions.
68 changes: 54 additions & 14 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
from matfree.backend.typing import Array, Callable


def tridiag_sym(krylov_depth, /, *, reortho: str = "full", custom_vjp: bool = True):
def tridiag_sym(
krylov_depth,
/,
*,
materialize: bool = True,
reortho: str = "full",
custom_vjp: bool = True,
):
r"""Construct an implementation of **tridiagonalisation**.
Uses pre-allocation, and full reorthogonalisation if `reortho` is set to `"full"`.
Expand Down Expand Up @@ -44,43 +51,70 @@ def tridiag_sym(krylov_depth, /, *, reortho: str = "full", custom_vjp: bool = Tr
"""

if reortho == "full":
return _tridiag_reortho_full(krylov_depth, custom_vjp=custom_vjp)
return _tridiag_reortho_full(
krylov_depth, custom_vjp=custom_vjp, materialize=materialize
)
if reortho == "none":
return _tridiag_reortho_none(krylov_depth, custom_vjp=custom_vjp)
return _tridiag_reortho_none(
krylov_depth, custom_vjp=custom_vjp, materialize=materialize
)

msg = f"reortho={reortho} unsupported. Choose eiter {'full', 'none'}."
raise ValueError(msg)


def _tridiag_reortho_full(krylov_depth: int, /, *, custom_vjp: bool, materialize: bool):
    """Build a tridiagonalisation routine on top of `hessenberg(..., reortho="full")`.

    Parameters
    ----------
    krylov_depth
        Forwarded unchanged to `hessenberg`.
    custom_vjp
        Forwarded unchanged to `hessenberg`.
    materialize
        If true, the returned routine yields the output of
        `_todense_tridiag_sym(diags, offdiags)`; otherwise it yields
        the tuple `(diags, offdiags)`.

    Returns
    -------
    estimate
        A function `estimate(matvec, vec, *params)` returning
        `((Q.T, matrix_or_diagonals), (v / norm(v), norm(v)))`.
    """
    # Implement via Arnoldi to use the reorthogonalised adjoints.
    # The complexity difference is minimal with full reortho.
    alg = hessenberg(krylov_depth, custom_vjp=custom_vjp, reortho="full")

    def estimate(matvec, vec, *params):
        Q, H, v, _norm = alg(matvec, vec, *params)

        # Evaluate the norm once and reuse it for both remainder entries.
        v_norm = linalg.vector_norm(v)
        remainder = (v / v_norm, v_norm)

        # Average H with its transpose, then read off the main diagonal
        # and the first off-diagonal.
        T = 0.5 * (H + H.T)
        diags = linalg.diagonal(T, offset=0)
        offdiags = linalg.diagonal(T, offset=1)

        if materialize:
            matrix = _todense_tridiag_sym(diags, offdiags)
        else:
            matrix = (diags, offdiags)

        decomposition = (Q.T, matrix)
        return decomposition, remainder

    return estimate


def _tridiag_reortho_none(krylov_depth, /, *, custom_vjp):
def _todense_tridiag_sym(diag, off_diag):
    """Sum three `linalg.diagonal_matrix` terms into a single matrix.

    `diag` is used with the default offset, and `off_diag` is used twice:
    once with offset -1 and once with offset 1.
    """
    main = linalg.diagonal_matrix(diag)
    lower = linalg.diagonal_matrix(off_diag, -1)
    upper = linalg.diagonal_matrix(off_diag, 1)
    return main + lower + upper


def _tridiag_reortho_none(krylov_depth: int, /, *, custom_vjp: bool, materialize: bool):
def estimate(matvec, vec, *params):
(Q, H), v = _estimate(matvec, vec, *params)

if materialize:
H = _todense_tridiag_sym(*H)
return (Q, H), v

def _estimate(matvec, vec, *params):
*values, _ = _tridiag_forward(matvec, krylov_depth, vec, *params)
return values

def estimate_fwd(matvec, vec, *params):
value = estimate(matvec, vec, *params)
value = _estimate(matvec, vec, *params)
return value, (value, (linalg.vector_norm(vec), *params))

def estimate_bwd(matvec, cache, vjp_incoming):
# Read incoming gradients and stack related quantities
print(tree_util.tree_map(np.shape, vjp_incoming))

(dxs, (dalphas, dbetas)), (dx_last, dbeta_last) = vjp_incoming
dxs = np.concatenate((dxs, dx_last[None]))
dbetas = np.concatenate((dbetas, dbeta_last[None]))
Expand All @@ -105,8 +139,8 @@ def estimate_bwd(matvec, cache, vjp_incoming):
return grads

if custom_vjp:
estimate = func.custom_vjp(estimate, nondiff_argnums=[0])
estimate.defvjp(estimate_fwd, estimate_bwd) # type: ignore
_estimate = func.custom_vjp(_estimate, nondiff_argnums=[0])
_estimate.defvjp(estimate_fwd, estimate_bwd) # type: ignore

return estimate

Expand Down Expand Up @@ -483,7 +517,7 @@ def _extract_diag(x, offset=0):
return linalg.diagonal_matrix(diag, offset=offset)


def bidiag(depth: int, /, matrix_shape):
def bidiag(depth: int, /, matrix_shape, materialize: bool = True):
"""Construct an implementation of **bidiagonalisation**.
Uses pre-allocation and full reorthogonalisation.
Expand All @@ -509,10 +543,6 @@ def bidiag(depth: int, /, matrix_shape):
msg3 = f"for a matrix with shape {matrix_shape}."
raise ValueError(msg1 + msg2 + msg3)

# todo: move the matvecs to the estimate() functions
# of tridiag and hessenberg. Then, update the SLQ functions.
# Finally, give all methods here a materialize=True argument
# and simplify the SLQ code massively.
def estimate(Av: Callable, vA: Callable, v0, *parameters):
v0_norm, length = _normalise(v0)
init_val = init(v0_norm)
Expand Down Expand Up @@ -561,6 +591,11 @@ def step(Av, vA, state: State, *parameters) -> State:

def extract(state: State, /):
_, uk_all, vk_all, alphas, betas, beta, vk = state

if materialize:
B = _todense_bidiag(alphas, betas[1:])
return uk_all.T, B, vk_all, (beta, vk)

return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk)

def _gram_schmidt_classical(vec, vectors): # Gram-Schmidt
Expand All @@ -577,4 +612,9 @@ def _normalise(vec):
length = linalg.vector_norm(vec)
return vec / length, length

def _todense_bidiag(d, e):
    """Add `linalg.diagonal_matrix(d)` and `linalg.diagonal_matrix(e, 1)`."""
    main = linalg.diagonal_matrix(d)
    upper = linalg.diagonal_matrix(e, 1)
    return main + upper

return estimate
5 changes: 2 additions & 3 deletions matfree/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ def svd_partial(
Shape of the matrix involved in matrix-vector and vector-matrix products.
"""
# Factorise the matrix
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape)
u, (d, e), vt, *_ = algorithm(Av, vA, v0)
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True)
u, B, vt, *_ = algorithm(Av, vA, v0)

# Compute SVD of factorisation
B = _bidiagonal_dense(d, e)
U, S, Vt = linalg.svd(B, full_matrices=False)

# Combine orthogonal transformations
Expand Down
33 changes: 6 additions & 27 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable, /) -> Callable
def estimate(matvec: Callable, vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
(basis, (diag, off_diag)), _ = tridiag_sym(matvec, vec, *parameters)
matrix = _todense_tridiag_sym(diag, off_diag)
(basis, matrix), _ = tridiag_sym(matvec, vec, *parameters)
# matrix = _todense_tridiag_sym(diag, off_diag)

funm = dense_funm(matrix)
e1 = np.eye(len(matrix))[0, :]
Expand All @@ -161,7 +161,7 @@ def integrand_funm_sym(matfun, order, /):
"""
# Todo: expect these to be passed by the user.
dense_funm = dense_funm_sym_eigh(matfun)
algorithm = decomp.tridiag_sym(order)
algorithm = decomp.tridiag_sym(order, materialize=True)

def quadform(matvec, v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
Expand All @@ -174,9 +174,8 @@ def matvec_flat(v_flat, *p):
flat, unflatten = tree_util.ravel_pytree(Av)
return flat

(_, (diag, off_diag)), _ = algorithm(matvec_flat, v0_flat, *parameters)
(_, dense), _ = algorithm(matvec_flat, v0_flat, *parameters)

dense = _todense_tridiag_sym(diag, off_diag)
fA = dense_funm(dense)
e1 = np.eye(len(fA))[0, :]
return length**2 * linalg.inner(e1, fA @ e1)
Expand Down Expand Up @@ -224,21 +223,19 @@ def matvec_flat(v_flat, *p):

w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat)
matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat))
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True)

def vecmat_flat(w_flat):
w = w_unflatten(w_flat)
wA = vecmat(w, *parameters)
return tree_util.ravel_pytree(wA)[0]

# Decompose into orthogonal-bidiag-orthogonal
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape)
matvec_flat_p = lambda v: matvec_flat(v)[0] # noqa: E731
output = algorithm(matvec_flat_p, vecmat_flat, v0_flat, *parameters)
u, (d, e), vt, *_ = output
u, B, vt, *_ = output

# Compute SVD of factorisation
B = _todense_bidiag(d, e)

# todo: turn the following lines into dense_funm_svd()
_, S, Vt = linalg.svd(B, full_matrices=False)

Expand Down Expand Up @@ -277,21 +274,3 @@ def fun(dense_matrix):
return linalg.funm_schur(dense_matrix, matfun)

return fun


# todo: if we move this logic to the decomposition algorithms
# (e.g. with a materialize=True flag in the decomposition construction),
# then all trace_of_funm implementations reduce to very few lines.


def _todense_tridiag_sym(diag, off_diag):
    """Return the sum of the main, offset -1 and offset 1 diagonal matrices.

    The main term is built from `diag`; both offset terms are built
    from `off_diag`.
    """
    centre = linalg.diagonal_matrix(diag)
    below = linalg.diagonal_matrix(off_diag, -1)
    above = linalg.diagonal_matrix(off_diag, 1)
    return centre + below + above


def _todense_bidiag(d, e):
    """Return the sum of the diagonal matrix of `d` and the offset-1 matrix of `e`."""
    return linalg.diagonal_matrix(d) + linalg.diagonal_matrix(e, 1)
8 changes: 4 additions & 4 deletions tests/test_decomp/test_bidiag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def Av(v):
def vA(v):
return v @ A

algorithm = decomp.bidiag(order, matrix_shape=np.shape(A))
algorithm = decomp.bidiag(order, matrix_shape=np.shape(A), materialize=False)
Us, Bs, Vs, (b, v), ln = algorithm(Av, vA, v0)
(d_m, e_m) = Bs

Expand Down Expand Up @@ -70,7 +70,7 @@ def test_error_too_high_depth(A):
max_depth = min(nrows, ncols) - 1

with testing.raises(ValueError, match=""):
_ = decomp.bidiag(max_depth + 1, matrix_shape=np.shape(A))
_ = decomp.bidiag(max_depth + 1, matrix_shape=np.shape(A), materialize=False)


@testing.parametrize("nrows", [5])
Expand All @@ -80,7 +80,7 @@ def test_error_too_low_depth(A):
"""Assert that a ValueError is raised when the depth is negative."""
min_depth = 0
with testing.raises(ValueError, match=""):
_ = decomp.bidiag(min_depth - 1, matrix_shape=np.shape(A))
_ = decomp.bidiag(min_depth - 1, matrix_shape=np.shape(A), materialize=False)


@testing.parametrize("nrows", [15])
Expand All @@ -98,7 +98,7 @@ def Av(v):
def vA(v):
return v @ A

algorithm = decomp.bidiag(0, matrix_shape=np.shape(A))
algorithm = decomp.bidiag(0, matrix_shape=np.shape(A), materialize=False)
Us, Bs, Vs, (b, v), ln = algorithm(Av, vA, v0)
(d_m, e_m) = Bs
assert np.shape(Us) == (nrows, 1)
Expand Down
18 changes: 4 additions & 14 deletions tests/test_decomp/test_tridiag_sym.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ def test_full_rank_reconstruction_is_exact(reortho, ndim):
vector = np.flip(np.arange(1.0, 1.0 + len(eigvals)))

# Run Lanczos approximation
algorithm = decomp.tridiag_sym(ndim, reortho=reortho)
(lanczos_vectors, tridiag), _ = algorithm(lambda s, p: p @ s, vector, matrix)
algorithm = decomp.tridiag_sym(ndim, reortho=reortho, materialize=True)
(lanczos_vectors, dense_matrix), _ = algorithm(lambda s, p: p @ s, vector, matrix)

# Reconstruct the original matrix from the full-order approximation
dense_matrix = _dense_tridiag_sym(*tridiag)
matrix_reconstructed = lanczos_vectors.T @ dense_matrix @ lanczos_vectors

if reortho == "full":
Expand Down Expand Up @@ -46,19 +45,10 @@ def test_mid_rank_reconstruction_satisfies_decomposition(ndim, krylov_depth, reo
vector = np.flip(np.arange(1.0, 1.0 + len(eigvals)))

# Run Lanczos approximation
algorithm = decomp.tridiag_sym(krylov_depth, reortho=reortho)
(lanczos_vectors, tridiag), (q, b) = algorithm(lambda s, p: p @ s, vector, matrix)
algorithm = decomp.tridiag_sym(krylov_depth, reortho=reortho, materialize=True)
(Q, T), (q, b) = algorithm(lambda s, p: p @ s, vector, matrix)

# Verify the decomposition
Q, T = lanczos_vectors, _dense_tridiag_sym(*tridiag)
tols = {"atol": 1e-5, "rtol": 1e-5}
e_K = np.eye(krylov_depth)[-1]
assert np.allclose(matrix @ Q.T, Q.T @ T + linalg.outer(e_K, q * b).T, **tols)


def _dense_tridiag_sym(diagonal, off_diagonal):
    """Sum the diagonal matrices of `diagonal` and of `off_diagonal` at offsets 1 and -1."""
    main = linalg.diagonal_matrix(diagonal)
    upper = linalg.diagonal_matrix(off_diagonal, 1)
    lower = linalg.diagonal_matrix(off_diagonal, -1)
    return main + upper + lower
2 changes: 1 addition & 1 deletion tests/test_decomp/test_tridiag_sym_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def matvec(s, p):

# Construct a vector-to-vector decomposition function
def decompose(f, *, custom_vjp):
kwargs = {"reortho": reortho, "custom_vjp": custom_vjp}
kwargs = {"reortho": reortho, "custom_vjp": custom_vjp, "materialize": False}
algorithm = decomp.tridiag_sym(krylov_order, **kwargs)
output = algorithm(matvec, *unflatten(f))
return tree_util.ravel_pytree(output)[0]
Expand Down
5 changes: 3 additions & 2 deletions tests/test_funm/test_funm_lanczos_sym.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@


@testing.parametrize("dense_funm", [funm.dense_funm_sym_eigh, funm.dense_funm_schur])
def test_funm_lanczos_sym_matches_eigh_implementation(dense_funm, n=11):
@testing.parametrize("reortho", ["full", "none"])
def test_funm_lanczos_sym_matches_eigh_implementation(dense_funm, reortho, n=11):
"""Test matrix-function-vector products via Lanczos' algorithm."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).
Expand All @@ -28,7 +29,7 @@ def fun(x):

# Compute the matrix-function vector product
dense_funm = dense_funm(fun)
lanczos = decomp.tridiag_sym(6)
lanczos = decomp.tridiag_sym(6, materialize=True, reortho=reortho)
matfun_vec = funm.funm_lanczos_sym(dense_funm, lanczos)
received = matfun_vec(matvec, v, matrix)
assert np.allclose(expected, received, atol=1e-6)

0 comments on commit fddc9a7

Please sign in to comment.