Skip to content

Commit

Permalink
Rename lanczos.py to decomp.py (#191)
Browse files Browse the repository at this point in the history
* Rename lanczos.py to decomp.py

* Rename decomp.alg_* functions

* Remove _full_reortho from function names

* Rename tridiag( to tridiag_sym(

* Delete unused functions
  • Loading branch information
pnkraemer authored May 27, 2024
1 parent f51cf6d commit cb86c32
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 83 deletions.
4 changes: 0 additions & 4 deletions matfree/backend/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,3 @@ def fori_loop(lower, upper, body_fun, init_val):

def while_loop(cond_fun, body_fun, init_val):
return jax.lax.while_loop(cond_fun, body_fun, init_val)


def array_map(fun, /, xs):
return jax.lax.map(fun, xs)
12 changes: 0 additions & 12 deletions matfree/backend/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,6 @@ def sum(x, /, axis=None): # noqa: A001
return jnp.sum(x, axis)


def array_min(x, /):
return jnp.amin(x)


def array_max(x, /, axis=None):
return jnp.amax(x, axis=axis)


def argmax(x, /, axis=None):
return jnp.argmax(x, axis=axis)

Expand All @@ -138,10 +130,6 @@ def argsort(x, /):
return jnp.argsort(x)


def elementwise_max(x1, x2, /):
return jnp.maximum(x1, x2)


def nanmean(x, /, axis=None):
return jnp.nanmean(x, axis)

Expand Down
7 changes: 0 additions & 7 deletions matfree/backend/progressbar.py

This file was deleted.

24 changes: 10 additions & 14 deletions matfree/lanczos.py → matfree/decomp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Lanczos-style matrix decompositions.
"""Matrix-free matrix decompositions.
This module includes various Lanczos-decompositions of matrices
(tridiagonalisation, bidiagonalisation, etc.).
(tri-diagonal, bi-diagonal, etc.).
For stochastic Lanczos quadrature, see
[matfree.stochtrace_funm][matfree.stochtrace_funm].
Expand All @@ -12,10 +12,6 @@
from matfree.backend import containers, control_flow, func, linalg, np
from matfree.backend.typing import Array, Callable, Tuple

# todo: rename this module, because we may easily include arnoldi here, too.
# what do we rename it to? krylov.py? decomp.py? krylovbasis.py?


# todo: rename svd_approx to svd_partial() because the algorithm is called
# "Partial SVD", not "Approximate SVD".

Expand Down Expand Up @@ -44,7 +40,7 @@ def svd_approx(
Shape of the matrix involved in matrix-vector and vector-matrix products.
"""
# Factorise the matrix
algorithm = alg_bidiag_full_reortho(Av, vA, depth, matrix_shape=matrix_shape)
algorithm = bidiag(Av, vA, depth, matrix_shape=matrix_shape)
u, (d, e), vt, _ = algorithm(v0)

# Compute SVD of factorisation
Expand All @@ -71,12 +67,10 @@ class _LanczosAlg(containers.NamedTuple):
"""Range of the for-loop used to decompose a matrix."""


def alg_tridiag_full_reortho(
Av: Callable, depth, /, validate_unit_2_norm=False
) -> Callable:
def tridiag_sym(Av: Callable, depth, /, validate_unit_2_norm=False) -> Callable:
"""Construct an implementation of **tridiagonalisation**.
Uses pre-allocation. Fully reorthogonalise vectors at every step.
Uses pre-allocation and full reorthogonalisation.
This algorithm assumes a **symmetric matrix**.
Expand Down Expand Up @@ -144,12 +138,12 @@ def extract(state: State, /):
return func.partial(_decompose_fori_loop, algorithm=alg)


def alg_bidiag_full_reortho(
def bidiag(
Av: Callable, vA: Callable, depth, /, matrix_shape, validate_unit_2_norm=False
):
"""Construct an implementation of **bidiagonalisation**.
Uses pre-allocation. Fully reorthogonalise vectors at every step.
Uses pre-allocation and full reorthogonalisation.
Works for **arbitrary matrices**. No symmetry required.
Expand Down Expand Up @@ -212,6 +206,8 @@ def extract(state: State, /):


def _validate_unit_2_norm(v, /):
# todo: replace this functionality with normalising internally.
#
# Lanczos assumes a unit-2-norm vector as an input
# We cannot raise an error based on values of the init_vec,
# but we can make it obvious that the result is unusable.
Expand Down Expand Up @@ -270,7 +266,7 @@ def _bidiagonal_dense(d, e):
return diag + offdiag


def _eigh_tridiag(diag, off_diag):
def _eigh_tridiag_sym(diag, off_diag):
# todo: once jax supports eigh_tridiagonal(eigvals_only=False),
# use it here. Until then: an eigen-decomposition of size (order + 1)
# does not hurt too much...
Expand Down
8 changes: 4 additions & 4 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Array([-4. , -2.1, -2.7, -1.9, -1.3, -3.5, -0.5, -0.1, 0.3, 1.5], dtype=float32)
"""

from matfree import lanczos
from matfree import decomp
from matfree.backend import containers, control_flow, func, linalg, np
from matfree.backend.typing import Array, Callable

Expand Down Expand Up @@ -101,21 +101,21 @@ def funm_lanczos_sym(matfun: Callable, order: int, matvec: Callable, /) -> Calla
This algorithm uses Lanczos' tridiagonalisation
and therefore applies only to symmetric matrices.
"""
algorithm = lanczos.alg_tridiag_full_reortho(matvec, order)
algorithm = decomp.tridiag_sym(matvec, order)

def estimate(vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
basis, (diag, off_diag) = algorithm(vec, *parameters)
eigvals, eigvecs = _eigh_tridiag(diag, off_diag)
eigvals, eigvecs = _eigh_tridiag_sym(diag, off_diag)

fx_eigvals = func.vmap(matfun)(eigvals)
return length * (basis.T @ (eigvecs @ (fx_eigvals * eigvecs[0, :])))

return estimate


def _eigh_tridiag(diag, off_diag):
def _eigh_tridiag_sym(diag, off_diag):
# todo: once jax supports eigh_tridiagonal(eigvals_only=False),
# use it here. Until then: an eigen-decomposition of size (order + 1)
# does not hurt too much...
Expand Down
10 changes: 5 additions & 5 deletions matfree/stochtrace_funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from matfree import lanczos
from matfree import decomp
from matfree.backend import func, linalg, np, tree_util

# todo: currently, all dense matrix-functions are computed
Expand Down Expand Up @@ -37,9 +37,9 @@ def matvec_flat(v_flat, *p):
flat, unflatten = tree_util.ravel_pytree(Av)
return flat

algorithm = lanczos.alg_tridiag_full_reortho(matvec_flat, order)
algorithm = decomp.tridiag_sym(matvec_flat, order)
_, (diag, off_diag) = algorithm(v0_flat, *parameters)
eigvals, eigvecs = _eigh_tridiag(diag, off_diag)
eigvals, eigvecs = _eigh_tridiag_sym(diag, off_diag)

# Since Q orthogonal (orthonormal) to v0, Q v = Q[0],
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
Expand Down Expand Up @@ -95,7 +95,7 @@ def vecmat_flat(w_flat):
return tree_util.ravel_pytree(wA)[0]

# Decompose into orthogonal-bidiag-orthogonal
algorithm = lanczos.alg_bidiag_full_reortho(
algorithm = decomp.bidiag(
lambda v: matvec_flat(v)[0], vecmat_flat, depth, matrix_shape=matrix_shape
)
output = algorithm(v0_flat, *parameters)
Expand All @@ -120,7 +120,7 @@ def _bidiagonal_dense(d, e):
return diag + offdiag


def _eigh_tridiag(diag, off_diag):
def _eigh_tridiag_sym(diag, off_diag):
# todo: once jax supports eigh_tridiagonal(eigvals_only=False),
# use it here. Until then: an eigen-decomposition of size (order + 1)
# does not hurt too much...
Expand Down
15 changes: 14 additions & 1 deletion matfree/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def symmetric_matrix_from_eigenvalues(eigvals, /):

# QR decompose. We need the orthogonal matrix.
# Treat Q as a stack of eigenvectors.
Q, R = linalg.qr_reduced(X)
Q, _R = linalg.qr_reduced(X)

# Treat Q as eigenvectors, and 'D' as eigenvalues.
# return Q D Q.T.
Expand All @@ -28,3 +28,16 @@ def asymmetric_matrix_from_singular_values(vals, /, nrows, ncols):
A /= nrows * ncols
U, S, Vt = linalg.svd(A, full_matrices=False)
return U @ linalg.diagonal(vals) @ Vt


def to_dense_bidiag(d, e, /, offset=1):
diag = linalg.diagonal_matrix(d)
offdiag = linalg.diagonal_matrix(e, offset=offset)
return diag + offdiag


def to_dense_tridiag_sym(d, e, /):
diag = linalg.diagonal_matrix(d)
offdiag1 = linalg.diagonal_matrix(e, offset=1)
offdiag2 = linalg.diagonal_matrix(e, offset=-1)
return diag + offdiag1 + offdiag2
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test the Golub-Kahan-Lanczos bi-diagonalisation with full re-orthogonalisation."""

from matfree import lanczos, test_util
from matfree import decomp, test_util
from matfree.backend import linalg, np, prng, testing


Expand All @@ -18,7 +18,7 @@ def A(nrows, ncols, num_significant_singular_vals):
@testing.parametrize("ncols", [49])
@testing.parametrize("num_significant_singular_vals", [4])
@testing.parametrize("order", [6]) # ~1.5 * num_significant_eigvals
def test_lanczos_bidiag_full_reortho(A, order):
def test_bidiag(A, order):
"""Test that Lanczos tridiagonalisation yields an orthogonal-tridiagonal decomp."""
nrows, ncols = np.shape(A)
key = prng.prng_key(1)
Expand All @@ -30,7 +30,7 @@ def Av(v):
def vA(v):
return v @ A

algorithm = lanczos.alg_bidiag_full_reortho(Av, vA, order, matrix_shape=np.shape(A))
algorithm = decomp.bidiag(Av, vA, order, matrix_shape=np.shape(A))
v0 /= linalg.vector_norm(v0)
Us, Bs, Vs, (b, v) = algorithm(v0)
(d_m, e_m) = Bs
Expand All @@ -47,7 +47,7 @@ def vA(v):
assert np.allclose(linalg.diagonal(UAVt), d_m, **tols_decomp)
assert np.allclose(linalg.diagonal(UAVt, 1), e_m, **tols_decomp)

B = _bidiagonal_dense(d_m, e_m)
B = test_util.to_dense_bidiag(d_m, e_m)
assert np.shape(B) == (order + 1, order + 1)
assert np.allclose(UAVt, B, **tols_decomp)

Expand All @@ -60,12 +60,6 @@ def vA(v):
assert np.allclose(AtUt, VtBtb_plus_bve, **tols_decomp)


def _bidiagonal_dense(d, e):
diag = linalg.diagonal_matrix(d)
offdiag = linalg.diagonal_matrix(e, 1)
return diag + offdiag


@testing.parametrize("nrows", [5])
@testing.parametrize("ncols", [3])
@testing.parametrize("num_significant_singular_vals", [3])
Expand All @@ -79,9 +73,7 @@ def test_error_too_high_depth(A):
def eye(v):
return v

_ = lanczos.alg_bidiag_full_reortho(
eye, eye, max_depth + 1, matrix_shape=np.shape(A)
)
_ = decomp.bidiag(eye, eye, max_depth + 1, matrix_shape=np.shape(A))


@testing.parametrize("nrows", [5])
Expand All @@ -95,9 +87,7 @@ def test_error_too_low_depth(A):
def eye(v):
return v

_ = lanczos.alg_bidiag_full_reortho(
eye, eye, min_depth - 1, matrix_shape=np.shape(A)
)
_ = decomp.bidiag(eye, eye, min_depth - 1, matrix_shape=np.shape(A))


@testing.parametrize("nrows", [15])
Expand All @@ -115,7 +105,7 @@ def Av(v):
def vA(v):
return v @ A

algorithm = lanczos.alg_bidiag_full_reortho(Av, vA, 0, matrix_shape=np.shape(A))
algorithm = decomp.bidiag(Av, vA, 0, matrix_shape=np.shape(A))
Us, Bs, Vs, (b, v) = algorithm(v0)
(d_m, e_m) = Bs
assert np.shape(Us) == (nrows, 1)
Expand Down Expand Up @@ -144,7 +134,7 @@ def Av(v):
def vA(v):
return v @ A

algorithm = lanczos.alg_bidiag_full_reortho(
algorithm = decomp.bidiag(
Av, vA, order, matrix_shape=np.shape(A), validate_unit_2_norm=True
)
Us, (d_m, e_m), Vs, (b, v) = algorithm(v0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for SVD functionality."""

from matfree import lanczos, test_util
from matfree import decomp, test_util
from matfree.backend import linalg, np, testing


Expand Down Expand Up @@ -34,7 +34,7 @@ def vA(v):

v0 = np.ones((ncols,))
v0 /= linalg.vector_norm(v0)
U, S, Vt = lanczos.svd_approx(v0, depth, Av, vA, matrix_shape=np.shape(A))
U, S, Vt = decomp.svd_approx(v0, depth, Av, vA, matrix_shape=np.shape(A))
U_, S_, Vt_ = linalg.svd(A, full_matrices=False)

tols_decomp = {"atol": 1e-5, "rtol": 1e-5}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test the Lanczos tri-diagonalisation with full re-orthogonalisation."""

from matfree import lanczos, test_util
from matfree import decomp, test_util
from matfree.backend import linalg, np, prng, testing


Expand All @@ -23,7 +23,7 @@ def test_max_order(A):
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(n,))
v0 /= linalg.vector_norm(v0)
algorithm = lanczos.alg_tridiag_full_reortho(lambda v: A @ v, order)
algorithm = decomp.tridiag_sym(lambda v: A @ v, order)
Q, (d_m, e_m) = algorithm(v0)

# Lanczos is not stable.
Expand All @@ -35,7 +35,7 @@ def test_max_order(A):
assert np.allclose(Q.T @ Q, np.eye(n), **tols_decomp), Q.T @ Q

# T = Q A Qt
T = _sym_tridiagonal_dense(d_m, e_m)
T = test_util.to_dense_tridiag_sym(d_m, e_m)
QAQt = Q @ A @ Q.T
assert np.shape(T) == (order + 1, order + 1)

Expand Down Expand Up @@ -65,7 +65,7 @@ def test_identity(A, order):
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(n,))
v0 /= linalg.vector_norm(v0)
algorithm = lanczos.alg_tridiag_full_reortho(lambda v: A @ v, order)
algorithm = decomp.tridiag_sym(lambda v: A @ v, order)
Q, tridiag = algorithm(v0)
(d_m, e_m) = tridiag

Expand All @@ -76,7 +76,7 @@ def test_identity(A, order):
assert np.allclose(Q @ Q.T, np.eye(order + 1), **tols_decomp), Q @ Q.T

# T = Q A Qt
T = _sym_tridiagonal_dense(d_m, e_m)
T = test_util.to_dense_tridiag_sym(d_m, e_m)
QAQt = Q @ A @ Q.T
assert np.shape(T) == (order + 1, order + 1)

Expand All @@ -89,13 +89,6 @@ def test_identity(A, order):
assert np.allclose(QAQt, T, **tols_decomp)


def _sym_tridiagonal_dense(d, e):
diag = linalg.diagonal_matrix(d)
offdiag1 = linalg.diagonal_matrix(e, 1)
offdiag2 = linalg.diagonal_matrix(e, -1)
return diag + offdiag1 + offdiag2


@testing.parametrize("n", [50])
@testing.parametrize("num_significant_eigvals", [4])
@testing.parametrize("order", [6]) # ~1.5 * num_significant_eigvals
Expand All @@ -107,9 +100,7 @@ def test_validate_unit_norm(A, order):
# Not normalized!
v0 = prng.normal(key, shape=(n,)) + 1.0

algorithm = lanczos.alg_tridiag_full_reortho(
lambda v: A @ v, order, validate_unit_2_norm=True
)
algorithm = decomp.tridiag_sym(lambda v: A @ v, order, validate_unit_2_norm=True)
Q, (d_m, e_m) = algorithm(v0)

# Since v0 is not normalized, all inputs are NaN
Expand Down
Loading

0 comments on commit cb86c32

Please sign in to comment.