diff --git a/lineax/__init__.py b/lineax/__init__.py index f7c5ee9..054971a 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -19,6 +19,7 @@ AbstractLinearOperator as AbstractLinearOperator, AddLinearOperator as AddLinearOperator, AuxLinearOperator as AuxLinearOperator, + BlockTridiagonalLinearOperator as BlockTridiagonalLinearOperator, ComposedLinearOperator as ComposedLinearOperator, conj as conj, diagonal as diagonal, @@ -27,6 +28,7 @@ FunctionLinearOperator as FunctionLinearOperator, has_unit_diagonal as has_unit_diagonal, IdentityLinearOperator as IdentityLinearOperator, + is_blocktridiagonal as is_blocktridiagonal, is_diagonal as is_diagonal, is_lower_triangular as is_lower_triangular, is_negative_semidefinite as is_negative_semidefinite, diff --git a/lineax/_operator.py b/lineax/_operator.py index 549f6ac..e76c041 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -51,6 +51,7 @@ ) from ._norm import default_floating_dtype from ._tags import ( + blocktridiagonal_tag, diagonal_tag, lower_triangular_tag, negative_semidefinite_tag, @@ -847,6 +848,89 @@ def out_structure(self): return jax.ShapeDtypeStruct(shape=(size,), dtype=self.diagonal.dtype) +class BlockTridiagonalLinearOperator(AbstractLinearOperator): + """As [`lineax.MatrixLinearOperator`][], but for specifically a block tridiagonal + matrix. + """ + + diagonal: Inexact[Array, "size N N"] + lower_diagonal: Inexact[Array, "size-1 N N"] + upper_diagonal: Inexact[Array, "size-1 N N"] + + def __init__( + self, + diagonal: Inexact[Array, "size N N"], + lower_diagonal: Inexact[Array, "size-1 N N"], + upper_diagonal: Inexact[Array, "size-1 N N"], + ): + """**Arguments:** + + - `diagonal`: A rank-3 JAX array. This is the diagonal of the matrix made + up of a number of NxN blocks. + - `lower_diagonal`: A rank-3 JAX array. This is the lower diagonal of the + matrix. + - `upper_diagonal`: A rank-3 JAX array. This is the upper diagonal of the + matrix. + + If `diagonal` has shape `(a, N, N)` then `lower_diagonal` and + `upper_diagonal` should both have shape `(a - 1, N, N)`. + """ + self.diagonal = inexact_asarray(diagonal) + self.lower_diagonal = inexact_asarray(lower_diagonal) + self.upper_diagonal = inexact_asarray(upper_diagonal) + (size, N, M) = self.diagonal.shape + if N != M: + raise ValueError(f"expecting square blocks, got {N} by {M} on diagonal") + if self.lower_diagonal.shape != (size - 1, N, N): + raise ValueError("lower_diagonal and diagonal do not have consistent shape") + if self.upper_diagonal.shape != (size - 1, N, N): + raise ValueError("upper_diagonal and diagonal do not have consistent shape") + + def mv(self, vector): + size, N, M = jnp.shape(self.diagonal) + v = vector.reshape(size, N) + a = jnp.einsum("ijk,ik -> ij", self.upper_diagonal, v[1:, :]).flatten() + b = jnp.einsum("ijk,ik -> ij", self.diagonal, v[:, :]).flatten() + c = jnp.einsum("ijk,ik -> ij", self.lower_diagonal, v[:-1, :]).flatten() + return b.at[:-N].add(a).at[N:].add(c) + + def as_matrix(self): + size, N, M = jnp.shape(self.diagonal) + zeros_block = jnp.zeros((N, N), self.diagonal.dtype) + block_matrix = jnp.array( + [ + [ + zeros_block, + ] + * size, + ] + * size + ) + arange = jnp.arange(size) + block_matrix = block_matrix.at[arange, arange].set(self.diagonal) + block_matrix = block_matrix.at[arange[1:], arange[:-1]].set(self.lower_diagonal) + block_matrix = block_matrix.at[arange[:-1], arange[1:]].set(self.upper_diagonal) + + blocked_concat = [jnp.concatenate(block, axis=1) for block in block_matrix] + matrix = jnp.concatenate(blocked_concat, axis=0) + return matrix + + def transpose(self): + return BlockTridiagonalLinearOperator( + jnp.transpose(self.diagonal, axes=[0, 2, 1]), + jnp.transpose(self.upper_diagonal, axes=[0, 2, 1]), + jnp.transpose(self.lower_diagonal, axes=[0, 2, 1]), + ) + + def in_structure(self): + size, N, _ = jnp.shape(self.diagonal) + return jax.ShapeDtypeStruct(shape=(N * size,), dtype=self.diagonal.dtype) + + def out_structure(self): + size, N, _ = jnp.shape(self.diagonal) + return jax.ShapeDtypeStruct(shape=(N * size,), dtype=self.diagonal.dtype) + + class TaggedLinearOperator(AbstractLinearOperator): """Wraps another linear operator and specifies that it has certain tags, e.g. representing symmetry. @@ -1202,6 +1286,7 @@ def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator: @linearise.register(IdentityLinearOperator) @linearise.register(DiagonalLinearOperator) @linearise.register(TridiagonalLinearOperator) +@linearise.register(BlockTridiagonalLinearOperator) def _(operator): return operator @@ -1340,6 +1425,7 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]: @diagonal.register(PyTreeLinearOperator) @diagonal.register(JacobianLinearOperator) @diagonal.register(FunctionLinearOperator) +@diagonal.register(BlockTridiagonalLinearOperator) def _(operator): return jnp.diag(operator.as_matrix()) @@ -1394,6 +1480,7 @@ def tridiagonal( @tridiagonal.register(PyTreeLinearOperator) @tridiagonal.register(JacobianLinearOperator) @tridiagonal.register(FunctionLinearOperator) +@tridiagonal.register(BlockTridiagonalLinearOperator) def _(operator): matrix = operator.as_matrix() assert matrix.ndim == 2 @@ -1423,6 +1510,70 @@ def _(operator): return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal +# blocktridiagonal + + +@ft.singledispatch +def blocktridiagonal( + operator: AbstractLinearOperator, +) -> tuple[ + Shaped[Array, " size N N"], + Shaped[Array, " size-1 N N"], + Shaped[Array, " size-1 N N"], +]: + """Extracts the blocked diagonal, lower diagonal, and upper diagonal, from a linear + operator. Returns three vectors. + + **Arguments:** + + - `operator`: a linear operator. + + **Returns:** + + A 3-tuple, consisting of: + + - The block diagonal of the matrix, represented as a vector. + - The block lower diagonal of the matrix, represented as a vector. + - The block upper diagonal of the matrix, represented as a vector. + + If the diagonal has shape `(a, N, N)` then the lower and upper diagonals + will have shape `(a - 1, N, N)`. + + For most operators this block extraction is not possible + """ + _default_not_implemented("blocktridiagonal", operator) + + +@blocktridiagonal.register(BlockTridiagonalLinearOperator) +def _(operator): + return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal + + +@blocktridiagonal.register(IdentityLinearOperator) +def _(operator): + size = operator.in_size() + diagonal = jnp.ones((size, 1, 1)) + off_diagonal = jnp.zeros((size - 1, 1, 1)) + return diagonal, off_diagonal, off_diagonal + + +@blocktridiagonal.register(DiagonalLinearOperator) +def _(operator): + (size,) = operator.diagonal.shape + off_diagonal = jnp.zeros((size - 1, 1, 1)) + return operator.diagonal.reshape(size, 1, 1), off_diagonal, off_diagonal + + +@blocktridiagonal.register(TridiagonalLinearOperator) +def _(operator): + (size,) = operator.diagonal.shape + return ( + operator.diagonal.reshape(size, 1, 1), + operator.lower_diagonal.reshape(size - 1, 1, 1), + operator.upper_diagonal.reshape(size - 1, 1, 1), + ) + + # is_symmetric @@ -1471,6 +1622,7 @@ def _(operator): @is_symmetric.register(TridiagonalLinearOperator) +@is_symmetric.register(BlockTridiagonalLinearOperator) def _(operator): return False @@ -1511,6 +1663,7 @@ def _(operator): @is_diagonal.register(TridiagonalLinearOperator) +@is_diagonal.register(BlockTridiagonalLinearOperator) def _(operator): return False @@ -1551,6 +1704,48 @@ def _(operator): return True +@is_tridiagonal.register(BlockTridiagonalLinearOperator) +def _(operator): + return False + + +# is_blocktridiagonal + + +@ft.singledispatch +def is_blocktridiagonal(operator: AbstractLinearOperator) -> bool: + """Returns whether an operator is marked as blocktridiagonal. + + See [the documentation on linear operator tags](../api/tags.md) for more + information. + + **Arguments:** + + - `operator`: a linear operator. + + **Returns:** + + Either `True` or `False.` + """ + _default_not_implemented("is_blocktridiagonal", operator) + + +@is_blocktridiagonal.register(BlockTridiagonalLinearOperator) +def _(operator): + return True + + +@is_blocktridiagonal.register(MatrixLinearOperator) +@is_blocktridiagonal.register(PyTreeLinearOperator) +@is_blocktridiagonal.register(JacobianLinearOperator) +@is_blocktridiagonal.register(FunctionLinearOperator) +@is_blocktridiagonal.register(IdentityLinearOperator) +@is_blocktridiagonal.register(DiagonalLinearOperator) +@is_blocktridiagonal.register(TridiagonalLinearOperator) +def _(operator): + return False + + # has_unit_diagonal @@ -1587,6 +1782,7 @@ def _(operator): @has_unit_diagonal.register(DiagonalLinearOperator) @has_unit_diagonal.register(TridiagonalLinearOperator) +@has_unit_diagonal.register(BlockTridiagonalLinearOperator) def _(operator): # TODO: refine this return False @@ -1628,6 +1824,7 @@ def _(operator): @is_lower_triangular.register(TridiagonalLinearOperator) +@is_lower_triangular.register(BlockTridiagonalLinearOperator) def _(operator): return False @@ -1668,6 +1865,7 @@ def _(operator): @is_upper_triangular.register(TridiagonalLinearOperator) +@is_upper_triangular.register(BlockTridiagonalLinearOperator) def _(operator): return False @@ -1708,6 +1906,7 @@ def _(operator): @is_positive_semidefinite.register(DiagonalLinearOperator) @is_positive_semidefinite.register(TridiagonalLinearOperator) +@is_positive_semidefinite.register(BlockTridiagonalLinearOperator) def _(operator): # TODO: refine this return False @@ -1749,6 +1948,7 @@ def _(operator): @is_negative_semidefinite.register(DiagonalLinearOperator) @is_negative_semidefinite.register(TridiagonalLinearOperator) +@is_negative_semidefinite.register(BlockTridiagonalLinearOperator) def _(operator): # TODO: refine this return False @@ -1896,6 +2096,7 @@ def _(operator): is_lower_triangular, is_upper_triangular, is_tridiagonal, + is_blocktridiagonal, ): @check.register(TangentLinearOperator) @@ -1943,6 +2144,7 @@ def _(operator, check=check): (is_positive_semidefinite, positive_semidefinite_tag), (is_negative_semidefinite, negative_semidefinite_tag), (is_tridiagonal, tridiagonal_tag), + (is_blocktridiagonal, blocktridiagonal_tag), ): @check.register(TaggedLinearOperator) @@ -1958,6 +2160,7 @@ def _(operator, check=check, tag=tag): is_positive_semidefinite, is_negative_semidefinite, is_tridiagonal, + is_blocktridiagonal, ): @check.register(AddLinearOperator) @@ -1978,6 +2181,7 @@ def _(operator): is_positive_semidefinite, is_negative_semidefinite, is_tridiagonal, + is_blocktridiagonal, ): @check.register(ComposedLinearOperator) diff --git a/lineax/_solve.py b/lineax/_solve.py index 45590e6..be4003b 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -34,6 +34,7 @@ AbstractLinearOperator, conj, IdentityLinearOperator, + is_blocktridiagonal, is_diagonal, is_lower_triangular, is_negative_semidefinite, @@ -490,6 +491,7 @@ def conj( _well_posed_diagonal_token = eqxi.str2jax("well_posed_diagonal_token") _tridiagonal_token = eqxi.str2jax("tridiagonal_token") _triangular_token = eqxi.str2jax("triangular_token") +_blocktridiagonal_token = eqxi.str2jax("blocktridiagonal_token") _cholesky_token = eqxi.str2jax("cholesky_token") _lu_token = eqxi.str2jax("lu_token") _svd_token = eqxi.str2jax("svd_token") @@ -509,6 +511,7 @@ def _lookup(token) -> AbstractLinearSolver: well_posed=True ), _tridiagonal_token: _solver.Tridiagonal(), # pyright: ignore + _blocktridiagonal_token: _solver.BlockTridiagonal(), # pyright: ignore _triangular_token: _solver.Triangular(), # pyright: ignore _cholesky_token: _solver.Cholesky(), # pyright: ignore _lu_token: _solver.LU(), # pyright: ignore @@ -527,6 +530,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): - If `well_posed=True`: - If the operator is diagonal, then use [`lineax.Diagonal`][]. - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. + - If the operator is block tridiagonal, then use [`lineax.BlockTridiagonal`][]. - If the operator is triangular, then use [`lineax.Triangular`][]. - If the matrix is positive or negative definite, then use [`lineax.Cholesky`][]. @@ -546,6 +550,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): - If the operator is non-square, then use [`lineax.QR`][]. - If the operator is diagonal, then use [`lineax.Diagonal`][]. - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. + - If the operator is block tridiagonal, then use [`lineax.BlockTridiagonal`][]. - If the operator is triangular, then use [`lineax.Triangular`][]. - If the matrix is positive or negative definite, then use [`lineax.Cholesky`][]. @@ -571,6 +576,8 @@ def _select_solver(self, operator: AbstractLinearOperator): token = _well_posed_diagonal_token elif is_tridiagonal(operator): token = _tridiagonal_token + elif is_blocktridiagonal(operator): + token = _blocktridiagonal_token elif is_lower_triangular(operator) or is_upper_triangular(operator): token = _triangular_token elif is_positive_semidefinite(operator) or is_negative_semidefinite( @@ -592,6 +599,8 @@ def _select_solver(self, operator: AbstractLinearOperator): token = _diagonal_token elif is_tridiagonal(operator): token = _tridiagonal_token + elif is_blocktridiagonal(operator): + token = _blocktridiagonal_token elif is_lower_triangular(operator) or is_upper_triangular(operator): token = _triangular_token elif is_positive_semidefinite(operator) or is_negative_semidefinite( diff --git a/lineax/_solver/__init__.py b/lineax/_solver/__init__.py index 425fc40..c66c628 100644 --- a/lineax/_solver/__init__.py +++ b/lineax/_solver/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from .bicgstab import BiCGStab as BiCGStab +from .blocktridiagonal import BlockTridiagonal as BlockTridiagonal from .cg import CG as CG, NormalCG as NormalCG from .cholesky import Cholesky as Cholesky from .diagonal import Diagonal as Diagonal diff --git a/lineax/_solver/blocktridiagonal.py b/lineax/_solver/blocktridiagonal.py new file mode 100644 index 0000000..5798d24 --- /dev/null +++ b/lineax/_solver/blocktridiagonal.py @@ -0,0 +1,126 @@ +from typing import Any +from typing_extensions import TypeAlias + +import jax +import jax.lax as lax +import jax.numpy as jnp +from jaxtyping import Array, PyTree + +from .._operator import ( + AbstractLinearOperator, + blocktridiagonal, + is_blocktridiagonal, + MatrixLinearOperator, +) +from .._solution import RESULTS +from .._solve import AbstractLinearSolver, linear_solve +from .misc import ( + pack_structures, + PackedStructures, + ravel_vector, + transpose_packed_structures, + unravel_solution, +) + + +_BlockTridiagonalState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures] + + +class BlockTridiagonal(AbstractLinearSolver[_BlockTridiagonalState]): + """Block tridiagonal solver for linear systems, using the Thomas algorithm.""" + + def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): + del options + if operator.in_size() != operator.out_size(): + raise ValueError( + """`BlockTridiagonal` may only be used for linear solves with + square matrices""" + ) + if not is_blocktridiagonal(operator): + raise ValueError( + """`BlockTridiagonal` may only be used for linear solves with + block tridiagonal `matrices`""" + ) + return blocktridiagonal(operator), pack_structures(operator) + + def compute( + self, + state: _BlockTridiagonalState, + vector, + options, + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + (diagonal, lower_diagonal, upper_diagonal), packed_structures = state + del state, options + vector = ravel_vector(vector, packed_structures) + + # + # Modifications to basic Thomas algorithm to work on block matrices + # notation from: https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm + # _p indicates prime, ie. `d_p` is the variable name for d' on wikipedia + # + + block_size = diagonal.shape[1] + size = diagonal.shape[0] + + vector = vector.reshape(size, block_size) + + def blockthomas_scan(prev_cd_carry, bd): + c_p, d_p, step = prev_cd_carry + # the index of `a` doesn't matter at step 0 as + # we won't use it at all. Same for `c` at final step + a_index = jnp.where(step > 0, step - 1, 0) + c_index = jnp.where(step < size, step, 0) + + b, d = bd + a, c = lower_diagonal[a_index, :, :], upper_diagonal[c_index, :, :] + + denom = MatrixLinearOperator(b - jnp.matmul(a, c_p)) + + def matrix_linear_solve_vec(step, x_vec): + y = linear_solve(denom, x_vec).value + step += 1 + return step, y + + _, new_d_p = matrix_linear_solve_vec(0, d - jnp.matmul(a, d_p)) + _, new_c_p = jax.vmap(matrix_linear_solve_vec, in_axes=(None, 0))(0, c.T) + new_c_p = new_c_p.T + return (new_c_p, new_d_p, step + 1), (new_c_p, new_d_p) + + def backsub(prev_x_carry, cd_p): + x_prev, step = prev_x_carry + c_p, d_p = cd_p + x_new = d_p - jnp.dot(c_p, x_prev) + return (x_new, step + 1), x_new + + init_thomas = (jnp.zeros((block_size, block_size)), jnp.zeros(block_size), 0) + init_backsub = (jnp.zeros(block_size), 0) + diag_vec = (diagonal, vector) + _, cd_p = lax.scan(blockthomas_scan, init_thomas, diag_vec) + _, solution = lax.scan(backsub, init_backsub, cd_p, reverse=True) + solution = solution.flatten() + + solution = unravel_solution(solution, packed_structures) + return solution, RESULTS.successful, {} + + def transpose(self, state: _BlockTridiagonalState, options: dict[str, Any]): + (diagonal, lower_diagonal, upper_diagonal), packed_structures = state + transposed_packed_structures = transpose_packed_structures(packed_structures) + transpose_diagonals = ( + jnp.transpose(diagonal, axes=[0, 2, 1]), + jnp.transpose(upper_diagonal, axes=[0, 2, 1]), + jnp.transpose(lower_diagonal, axes=[0, 2, 1]), + ) + transpose_state = (transpose_diagonals, transposed_packed_structures) + return transpose_state, options + + def conj(self, state: _BlockTridiagonalState, options: dict[str, Any]): + (diagonal, lower_diagonal, upper_diagonal), packed_structures = state + conj_diagonals = (diagonal.conj(), lower_diagonal.conj(), upper_diagonal.conj()) + conj_state = (conj_diagonals, packed_structures) + return conj_state, options + + def allow_dependent_columns(self, operator): + return False + + def allow_dependent_rows(self, operator): + return False diff --git a/lineax/_tags.py b/lineax/_tags.py index 8f7cb9a..3270d3a 100644 --- a/lineax/_tags.py +++ b/lineax/_tags.py @@ -24,6 +24,7 @@ def __repr__(self): symmetric_tag = _HasRepr("symmetric_tag") diagonal_tag = _HasRepr("diagonal_tag") tridiagonal_tag = _HasRepr("tridiagonal_tag") +blocktridiagonal_tag = _HasRepr("blocktridiagonal_tag") unit_diagonal_tag = _HasRepr("unit_diagonal_tag") lower_triangular_tag = _HasRepr("lower_triangular_tag") upper_triangular_tag = _HasRepr("upper_triangular_tag") @@ -41,6 +42,7 @@ def __repr__(self): positive_semidefinite_tag, negative_semidefinite_tag, tridiagonal_tag, + blocktridiagonal_tag, ): @transpose_tags_rules.append diff --git a/tests/test_operator.py b/tests/test_operator.py index 36be4ae..ed2d21c 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -141,6 +141,29 @@ def test_diagonal(dtype, getkey): assert jnp.allclose(lx.diagonal(operator), matrix_diag) +@pytest.mark.parametrize("dtype", (jnp.float64,)) +def test_blocktridiagonal(dtype, getkey): + tol = 1e-4 + diag = jr.normal(getkey(), (10, 3, 3), dtype=dtype) + upper_diag = jr.normal(getkey(), (9, 3, 3), dtype=dtype) + lower_diag = jr.normal(getkey(), (9, 3, 3), dtype=dtype) + + blocktridiag = lx.BlockTridiagonalLinearOperator(diag, lower_diag, upper_diag) + full_matrix = blocktridiag.as_matrix() + + x_true = jr.normal(getkey(), (30,), dtype=dtype) + + b_BTD = blocktridiag.mv(x_true) + b = full_matrix @ x_true + + assert tree_allclose(b, b_BTD, atol=tol, rtol=tol) + + BTD_soln = lx.linear_solve(blocktridiag, b) + soln = lx.linear_solve(lx.MatrixLinearOperator(full_matrix), b) + + assert tree_allclose(BTD_soln.value, soln.value, atol=tol, rtol=tol) + + @pytest.mark.parametrize("dtype", (jnp.float64,)) def test_is_symmetric(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype)