From c9ac24d20a13dff99f7a8eb1a25684ec0ab09665 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?=
Date: Tue, 28 May 2024 08:47:59 +0200
Subject: [PATCH] Add citation information to matfree's Arnoldi and Lanczos iterations (#197)

* Move citation information to the top of the README

* Add the bibtex entry to Lanczos and Arnoldi implementations

* Move bidiagonalisation to the bottom of the file so that Arnoldi and Lanczos are next to each other

* Mention that bidiagonalisation is not differentiable

* Put a date on the deprecation policy
---
 README.md         |  28 +++-
 matfree/bounds.py |   2 +-
 matfree/decomp.py | 335 +++++++++++++++++++++++++++-------------------
 matfree/pinv.py   |   3 +-
 4 files changed, 222 insertions(+), 146 deletions(-)

diff --git a/README.md b/README.md
index 3d146cd..566ac07 100644
--- a/README.md
+++ b/README.md
@@ -5,14 +5,15 @@
 [![image](https://img.shields.io/pypi/l/matfree.svg)](https://pypi.python.org/pypi/matfree)
 [![image](https://img.shields.io/pypi/pyversions/matfree.svg)](https://pypi.python.org/pypi/matfree)
 
-Randomised and deterministic matrix-free methods for trace estimation, matrix functions, and/or matrix factorisations.
+Randomised and deterministic matrix-free methods for trace estimation, functions of matrices, and/or matrix factorisations.
 Builds on [JAX](https://jax.readthedocs.io/en/latest/).
 
 - ⚡ Stochastic **trace estimation** including batching, control variates, and uncertainty quantification
-- ⚡ A stand-alone implementation of **stochastic Lanczos quadrature**
+- ⚡ A stand-alone implementation of **stochastic Lanczos quadrature** for traces of functions of matrices
 - ⚡ Matrix-decomposition algorithms for **large sparse eigenvalue problems**: tridiagonalisation, bidiagonalisation, Hessenberg factorisation via Lanczos and Arnoldi iterations
-- ⚡ Polynomial methods for approximating **functions of large matrices**
+- ⚡ Chebyshev, Lanczos, and Arnoldi-based methods for approximating **functions of large matrices**
+- ⚡ **Gradients of functions of large matrices** (like in [this paper](https://arxiv.org/abs/2405.17277)) via differentiable Lanczos and Arnoldi iterations
 - ⚡ Partial Cholesky **preconditioners** with and without pivoting
 
 and many other things.
@@ -102,6 +103,27 @@ These tutorials include, among other things:
 
 [_Let us know_](https://github.com/pnkraemer/matfree/issues) what you use matfree for!
 
+**Citation**
+
+Thank you for using Matfree!
+If you are using Matfree's differentiable Lanczos or Arnoldi iterations, then you
+are using the algorithms from [this paper](https://arxiv.org/abs/2405.17277).
+We would appreciate it if you cited it as follows:
+
+```bibtex
+@article{kraemer2024gradients,
+    title={Gradients of functions of large matrices},
+    author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and Roy, Hrittik and Hauberg, S\o{}ren},
+    journal={arXiv preprint arXiv:2405.17277},
+    year={2024}
+}
+```
+
+Some of Matfree's docstrings contain additional bibliographic information.
+For example, the functions in `matfree.bounds` link to bibtex entries for the articles associated with each bound.
+Go check out the [API documentation](https://pnkraemer.github.io/matfree/).
+
+
 ## Use Matfree's continuous integration
 
diff --git a/matfree/bounds.py b/matfree/bounds.py
index 3465657..dfceefe 100644
--- a/matfree/bounds.py
+++ b/matfree/bounds.py
@@ -18,7 +18,7 @@ def baigolub96_logdet_spd(bound_spectrum, /, nrows, trace, norm_frobenius_square
 
     ??? note "BibTex for Bai and Golub (1996)"
-        ```tex
+        ```bibtex
         @article{bai1996bounds,
             title={Bounds for the trace of the inverse and the determinant
             of symmetric positive definite matrices},

diff --git a/matfree/decomp.py b/matfree/decomp.py
index 6a09e38..289d98f 100644
--- a/matfree/decomp.py
+++ b/matfree/decomp.py
@@ -73,13 +73,32 @@ def tridiag_sym(
 ):
     """Construct an implementation of **tridiagonalisation**.
 
-    Uses pre-allocation and full reorthogonalisation.
+    Uses pre-allocation, and full reorthogonalisation if `reortho` is set to `"full"`.
+    It tends to be a good idea to use full reorthogonalisation.
 
     This algorithm assumes a **symmetric matrix**.
 
     Decompose a matrix into a product of orthogonal-**tridiagonal**-orthogonal matrices.
     Use this algorithm for approximate **eigenvalue** decompositions.
 
+    Setting `custom_vjp` to `True` implies using efficient, numerically stable
+    gradients of the Lanczos iteration, following the approach proposed by
+    Krämer et al. (2024).
+    These gradients are exact, so there is little reason not to use them.
+    If you use this configuration, please consider
+    citing Krämer et al. (2024; bibtex below).
+
+    ??? note "BibTex for Krämer et al. (2024)"
+        ```bibtex
+        @article{kraemer2024gradients,
+            title={Gradients of functions of large matrices},
+            author={Kr\"amer, Nicholas and Moreno-Mu\\~noz, Pablo and
+            Roy, Hrittik and Hauberg, S\\o{}ren},
+            journal={arXiv preprint arXiv:2405.17277},
+            year={2024}
+        }
+        ```
+
     """
     if reortho == "full":
@@ -276,146 +295,6 @@ def _tridiag_adjoint_step(
     return (xs_all, xi, lambda_), (gradient_increment, lambda_, mu, nu, xi)
 
 
-def bidiag(
-    Av: Callable, vA: Callable, depth, /, matrix_shape, validate_unit_2_norm=False
-):
-    """Construct an implementation of **bidiagonalisation**.
-
-    Uses pre-allocation and full reorthogonalisation.
-
-    Works for **arbitrary matrices**. No symmetry required.
-
-    Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices.
-    Use this algorithm for approximate **singular value** decompositions.
-    """
-    nrows, ncols = matrix_shape
-    max_depth = min(nrows, ncols) - 1
-    if depth > max_depth or depth < 0:
-        msg1 = f"Depth {depth} exceeds the matrix' dimensions. "
-        msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} "
-        msg3 = f"for a matrix with shape {matrix_shape}."
-        raise ValueError(msg1 + msg2 + msg3)
-
-    class State(containers.NamedTuple):
-        i: int
-        Us: Array
-        Vs: Array
-        alphas: Array
-        betas: Array
-        beta: Array
-        vk: Array
-
-    def init(init_vec: Array) -> State:
-        if validate_unit_2_norm:
-            init_vec = _validate_unit_2_norm(init_vec)
-
-        alphas = np.zeros((depth + 1,))
-        betas = np.zeros((depth + 1,))
-        Us = np.zeros((depth + 1, nrows))
-        Vs = np.zeros((depth + 1, ncols))
-        v0, _ = _normalise(init_vec)
-        return State(0, Us, Vs, alphas, betas, np.zeros(()), v0)
-
-    def apply(state: State, *parameters) -> State:
-        i, Us, Vs, alphas, betas, beta, vk = state
-        Vs = Vs.at[i].set(vk)
-        betas = betas.at[i].set(beta)
-
-        uk = Av(vk, *parameters) - beta * Us[i - 1]
-        uk, alpha = _normalise(uk)
-        uk, *_ = _gram_schmidt_classical(uk, Us)  # full reorthogonalisation
-        Us = Us.at[i].set(uk)
-        alphas = alphas.at[i].set(alpha)
-
-        vk = vA(uk, *parameters) - alpha * vk
-        vk, beta = _normalise(vk)
-        vk, *_ = _gram_schmidt_classical(vk, Vs)  # full reorthogonalisation
-
-        return State(i + 1, Us, Vs, alphas, betas, beta, vk)
-
-    def extract(state: State, /):
-        _, uk_all, vk_all, alphas, betas, beta, vk = state
-        return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk)
-
-    alg = _LanczosAlg(
-        init=init, step=apply, extract=extract, lower_upper=(0, depth + 1)
-    )
-    return func.partial(_decompose_fori_loop, algorithm=alg)
-
-
-def _validate_unit_2_norm(v, /):
-    # todo: replace this functionality with normalising internally.
-    #
-    # Lanczos assumes a unit-2-norm vector as an input
-    # We cannot raise an error based on values of the init_vec,
-    # but we can make it obvious that the result is unusable.
-    is_not_normalized = np.abs(linalg.vector_norm(v) - 1.0) > 10 * np.finfo_eps(v.dtype)
-    return control_flow.cond(
-        is_not_normalized, lambda s: np.nan() * np.ones_like(s), lambda s: s, v
-    )
-
-
-def _gram_schmidt_classical(vec, vectors):  # Gram-Schmidt
-    vec, coeffs = control_flow.scan(_gram_schmidt_classical_step, vec, xs=vectors)
-    vec, length = _normalise(vec)
-    return vec, length, coeffs
-
-
-def _gram_schmidt_classical_step(vec1, vec2):
-    coeff = linalg.inner(vec1, vec2)
-    vec_ortho = vec1 - coeff * vec2
-    return vec_ortho, coeff
-
-
-def _normalise(vec):
-    length = linalg.vector_norm(vec)
-    return vec / length, length
-
-
-def _decompose_fori_loop(v0, *parameters, algorithm: _LanczosAlg):
-    r"""Decompose a matrix purely based on matvec-products with A.
-
-    The behaviour of this function is equivalent to
-
-    ```python
-    def decompose(v0, *matvec_funs, algorithm):
-        init, step, extract, (lower, upper) = algorithm
-        state = init(v0)
-        for _ in range(lower, upper):
-            state = step(state, *matvec_funs)
-        return extract(state)
-    ```
-
-    but the implementation uses JAX' fori_loop.
-    """
-    init, step, extract, (lower, upper) = algorithm
-    init_val = init(v0)
-
-    def body_fun(_, s):
-        return step(s, *parameters)
-
-    result = control_flow.fori_loop(lower, upper, body_fun=body_fun, init_val=init_val)
-    return extract(result)
-
-
-def _bidiagonal_dense(d, e):
-    diag = linalg.diagonal_matrix(d)
-    offdiag = linalg.diagonal_matrix(e, 1)
-    return diag + offdiag
-
-
-def _eigh_tridiag_sym(diag, off_diag):
-    # todo: once jax supports eigh_tridiagonal(eigvals_only=False),
-    # use it here. Until then: an eigen-decomposition of size (order + 1)
-    # does not hurt too much...
-    diag = linalg.diagonal_matrix(diag)
-    offdiag1 = linalg.diagonal_matrix(off_diag, -1)
-    offdiag2 = linalg.diagonal_matrix(off_diag, 1)
-    dense_matrix = diag + offdiag1 + offdiag2
-    eigvals, eigvecs = linalg.eigh(dense_matrix)
-    return eigvals, eigvecs
-
-
 def hessenberg(
     matvec,
     krylov_depth,
@@ -425,6 +304,32 @@
     custom_vjp: bool = True,
     reortho_vjp: str = "match",
 ):
+    """Construct a **Hessenberg-factorisation** via the Arnoldi iteration.
+
+    Uses pre-allocation, and full reorthogonalisation if `reortho` is set to `"full"`.
+    It tends to be a good idea to use full reorthogonalisation.
+
+    This algorithm works for **arbitrary matrices**.
+
+    Setting `custom_vjp` to `True` implies using efficient, numerically stable
+    gradients of the Arnoldi iteration, following the approach proposed by
+    Krämer et al. (2024).
+    These gradients are exact, so there is little reason not to use them.
+    If you use this configuration,
+    please consider citing Krämer et al. (2024; bibtex below).
+
+    ??? note "BibTex for Krämer et al. (2024)"
+        ```bibtex
+        @article{kraemer2024gradients,
+            title={Gradients of functions of large matrices},
+            author={Kr\"amer, Nicholas and Moreno-Mu\\~noz, Pablo and
+            Roy, Hrittik and Hauberg, S\\o{}ren},
+            journal={arXiv preprint arXiv:2405.17277},
+            year={2024}
+        }
+        ```
+
+    """
     reortho_expected = ["none", "full"]
     if reortho not in reortho_expected:
         msg = f"Unexpected input for {reortho}: either of {reortho_expected} expected."
@@ -635,3 +540,151 @@ def _hessenberg_adjoint_step(
 def _extract_diag(x, offset=0):
     diag = linalg.diagonal(x, offset=offset)
     return linalg.diagonal_matrix(diag, offset=offset)
+
+
+def bidiag(
+    Av: Callable, vA: Callable, depth, /, matrix_shape, validate_unit_2_norm=False
+):
+    """Construct an implementation of **bidiagonalisation**.
+
+    Uses pre-allocation and full reorthogonalisation.
+
+    Works for **arbitrary matrices**. No symmetry required.
+
+    Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices.
+    Use this algorithm for approximate **singular value** decompositions.
+
+    ??? note "A note about differentiability"
+        Unlike [tridiag_sym][matfree.decomp.tridiag_sym] or
+        [hessenberg][matfree.decomp.hessenberg], this function's reverse-mode
+        derivatives are very inefficient. Custom gradients for bidiagonalisation
+        are a work in progress, and if you need to differentiate the decompositions,
+        consider using [tridiag_sym][matfree.decomp.tridiag_sym] for the time being.
+
+    """
+    nrows, ncols = matrix_shape
+    max_depth = min(nrows, ncols) - 1
+    if depth > max_depth or depth < 0:
+        msg1 = f"Depth {depth} exceeds the matrix' dimensions. "
+        msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} "
+        msg3 = f"for a matrix with shape {matrix_shape}."
+        raise ValueError(msg1 + msg2 + msg3)
+
+    class State(containers.NamedTuple):
+        i: int
+        Us: Array
+        Vs: Array
+        alphas: Array
+        betas: Array
+        beta: Array
+        vk: Array
+
+    def init(init_vec: Array) -> State:
+        if validate_unit_2_norm:
+            init_vec = _validate_unit_2_norm(init_vec)
+
+        alphas = np.zeros((depth + 1,))
+        betas = np.zeros((depth + 1,))
+        Us = np.zeros((depth + 1, nrows))
+        Vs = np.zeros((depth + 1, ncols))
+        v0, _ = _normalise(init_vec)
+        return State(0, Us, Vs, alphas, betas, np.zeros(()), v0)
+
+    def apply(state: State, *parameters) -> State:
+        i, Us, Vs, alphas, betas, beta, vk = state
+        Vs = Vs.at[i].set(vk)
+        betas = betas.at[i].set(beta)
+
+        uk = Av(vk, *parameters) - beta * Us[i - 1]
+        uk, alpha = _normalise(uk)
+        uk, *_ = _gram_schmidt_classical(uk, Us)  # full reorthogonalisation
+        Us = Us.at[i].set(uk)
+        alphas = alphas.at[i].set(alpha)
+
+        vk = vA(uk, *parameters) - alpha * vk
+        vk, beta = _normalise(vk)
+        vk, *_ = _gram_schmidt_classical(vk, Vs)  # full reorthogonalisation
+
+        return State(i + 1, Us, Vs, alphas, betas, beta, vk)
+
+    def extract(state: State, /):
+        _, uk_all, vk_all, alphas, betas, beta, vk = state
+        return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk)
+
+    alg = _LanczosAlg(
+        init=init, step=apply, extract=extract, lower_upper=(0, depth + 1)
+    )
+    return func.partial(_decompose_fori_loop, algorithm=alg)
+
+
+def _validate_unit_2_norm(v, /):
+    # todo: replace this functionality with normalising internally.
+    #
+    # Lanczos assumes a unit-2-norm vector as an input
+    # We cannot raise an error based on values of the init_vec,
+    # but we can make it obvious that the result is unusable.
+    is_not_normalized = np.abs(linalg.vector_norm(v) - 1.0) > 10 * np.finfo_eps(v.dtype)
+    return control_flow.cond(
+        is_not_normalized, lambda s: np.nan() * np.ones_like(s), lambda s: s, v
+    )
+
+
+def _gram_schmidt_classical(vec, vectors):  # Gram-Schmidt
+    vec, coeffs = control_flow.scan(_gram_schmidt_classical_step, vec, xs=vectors)
+    vec, length = _normalise(vec)
+    return vec, length, coeffs
+
+
+def _gram_schmidt_classical_step(vec1, vec2):
+    coeff = linalg.inner(vec1, vec2)
+    vec_ortho = vec1 - coeff * vec2
+    return vec_ortho, coeff
+
+
+def _normalise(vec):
+    length = linalg.vector_norm(vec)
+    return vec / length, length
+
+
+def _decompose_fori_loop(v0, *parameters, algorithm: _LanczosAlg):
+    r"""Decompose a matrix purely based on matvec-products with A.
+
+    The behaviour of this function is equivalent to
+
+    ```python
+    def decompose(v0, *matvec_funs, algorithm):
+        init, step, extract, (lower, upper) = algorithm
+        state = init(v0)
+        for _ in range(lower, upper):
+            state = step(state, *matvec_funs)
+        return extract(state)
+    ```
+
+    but the implementation uses JAX' fori_loop.
+    """
+    init, step, extract, (lower, upper) = algorithm
+    init_val = init(v0)
+
+    def body_fun(_, s):
+        return step(s, *parameters)
+
+    result = control_flow.fori_loop(lower, upper, body_fun=body_fun, init_val=init_val)
+    return extract(result)
+
+
+def _bidiagonal_dense(d, e):
+    diag = linalg.diagonal_matrix(d)
+    offdiag = linalg.diagonal_matrix(e, 1)
+    return diag + offdiag
+
+
+def _eigh_tridiag_sym(diag, off_diag):
+    # todo: once jax supports eigh_tridiagonal(eigvals_only=False),
+    # use it here. Until then: an eigen-decomposition of size (order + 1)
+    # does not hurt too much...
+    diag = linalg.diagonal_matrix(diag)
+    offdiag1 = linalg.diagonal_matrix(off_diag, -1)
+    offdiag2 = linalg.diagonal_matrix(off_diag, 1)
+    dense_matrix = diag + offdiag1 + offdiag2
+    eigvals, eigvecs = linalg.eigh(dense_matrix)
+    return eigvals, eigvecs
diff --git a/matfree/pinv.py b/matfree/pinv.py
index 69c6a4d..7ff496d 100644
--- a/matfree/pinv.py
+++ b/matfree/pinv.py
@@ -8,7 +8,8 @@ def _warn_deprecated():
     msg = "The module matfree.pinv has been deprecated and will be removed soon. "
     msg += "The removal will happen either in v0.0.17 or in v0.1.0, "
-    msg += "depending on what comes first. "
+    msg += "or on the 15th of June 2024, "
+    msg += "depending on which of the three comes first. "
    msg += "If your code relies on matfree.pinv, create an issue *now*."
     warnings.warn(msg, DeprecationWarning, stacklevel=1)
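
For anyone who wants to try the relocated `bidiag` routine, a minimal usage sketch follows. It relies only on the signature and the return layout of `extract` visible in the diff above; the test matrix, the depth of 5, and the initial vector are illustrative assumptions and are not part of the patch.

```python
# Illustrative sketch (not part of the patch): calling matfree.decomp.bidiag
# as defined above. The matrix, the depth, and the initial vector are made up.
import jax
import jax.numpy as jnp

from matfree import decomp

# A small rectangular test matrix (assumed for the example).
A = jax.random.normal(jax.random.PRNGKey(0), shape=(20, 10))


def Av(v):
    # Matrix-vector product with A.
    return A @ v


def vA(u):
    # Vector-matrix product with A, i.e. a matvec with A.T.
    return u @ A


# Build the decomposition routine; 0 <= depth <= min(nrows, ncols) - 1 must hold.
bidiagonalise = decomp.bidiag(Av, vA, 5, matrix_shape=A.shape)

# The iteration expects an initial vector with unit Euclidean norm.
v0 = jnp.ones((10,))
v0 = v0 / jnp.linalg.norm(v0)

# Returns the left vectors, the bidiagonal coefficients (alphas, betas),
# the right vectors, and a residual (beta, vk), matching extract() above.
U, (alphas, betas), V, (beta, vk) = bidiagonalise(v0)
```

As the new docstring notes, reverse-mode differentiation through this routine is currently inefficient; for gradient-based work, `tridiag_sym` or `hessenberg` with `custom_vjp=True` are the intended entry points.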