Skip to content

Commit

Permalink
Enable docstring-tests and write some missing docstrings (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer authored Sep 10, 2024
1 parent 64d31c0 commit 2406b9f
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 3 deletions.
6 changes: 3 additions & 3 deletions matfree/funm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Matrix-free implementations of functions of matrices.
r"""Matrix-free implementations of functions of matrices.
This includes matrix-function-vector products
$$
(f, A, v, p) \\mapsto f(A(p))v
(f, A, v, p) \mapsto f(A(p))v
$$
as well as matrix-function extensions for stochastic trace estimation,
which provide
$$
(f, A, v, p) \\mapsto v^\\top f(A(p))v.
(f, A, v, p) \mapsto v^\top f(A(p))v.
$$
Plug these integrands into
Expand Down
10 changes: 10 additions & 0 deletions matfree/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,41 @@ def asymmetric_matrix_from_singular_values(vals, /, nrows, ncols):


def to_dense_bidiag(d, e, /, offset=1):
"""Materialize a bidiagonal matrix."""
diag = linalg.diagonal_matrix(d)
offdiag = linalg.diagonal_matrix(e, offset=offset)
return diag + offdiag


def to_dense_tridiag_sym(d, e, /):
"""Materialize a symmetric tridiagonal matrix."""
diag = linalg.diagonal_matrix(d)
offdiag1 = linalg.diagonal_matrix(e, offset=1)
offdiag2 = linalg.diagonal_matrix(e, offset=-1)
return diag + offdiag1 + offdiag2


def tree_random_like(key, tree, *, generate_func=prng.normal):
"""Fill a tree with random values."""
flat, unflatten = tree_util.ravel_pytree(tree)
flat_like = generate_func(key, shape=flat.shape, dtype=flat.dtype)
return unflatten(flat_like)


def assert_columns_orthonormal(Q, /):
"""Assert that the columns in a matrix are orthonormal."""
eye_like = Q.T @ Q
ref = np.eye(len(eye_like))
assert_allclose(eye_like, ref)


def assert_allclose(a, b, /):
"""Assert that two arrays are close.
This function uses a different default tolerance to
jax.numpy.allclose. Instead of fixing values, the tolerance
depends on the floating-point precision of the input variables.
"""
a = np.asarray(a)
b = np.asarray(b)
tol = np.sqrt(np.finfo_eps(np.dtype(b)))
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ select = [
"EM",
# tryceratops:
"TRY",
# Docstrings:
"D",
]
ignore = [
# warning: `one-blank-line-before-class` (D203) and `no-blank-line-before-class` (D211) are incompatible.
Expand All @@ -125,6 +127,11 @@ line-ending = "lf"
quote-style = "double"
skip-magic-trailing-comma = true

[tool.ruff.per-file-ignores]
# Ignore all directories named `tests`.
"tests/**" = ["D"]
"matfree/backend/**" = ["D"]

[tool.ruff.lint.isort]
split-on-trailing-comma = false

Expand Down
1 change: 1 addition & 0 deletions tutorials/4_control_variates.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@


def matvec_ctrl(v):
"""Evaluate a matrix-vector product with a control variate."""
return A @ v - diagonal_ctrl * v


Expand Down
1 change: 1 addition & 0 deletions tutorials/6_low_memory_trace_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@


def large_matvec(v):
"""Evaluate a (dummy for a) large matrix-vector product."""
return 1.2345 * v


Expand Down

0 comments on commit 2406b9f

Please sign in to comment.