Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename hutchinson functions and delete unused vdC code #163

Merged
merged 3 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Estimate the trace of the matrix:

```python
>>> key = jax.random.PRNGKey(1)
>>> normal = hutchinson.normal(shape=(2,))
>>> normal = hutchinson.sampler_normal(shape=(2,))
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
>>>
>>> print(jnp.round(trace))
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/control_variates.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def f(x):
_, jvp = func.linearize(f, x0)
J = func.jacfwd(f)(x0)
trace = linalg.trace(J)
sample_fun = hutchinson.normal(shape=(n,), dtype=float)
sample_fun = hutchinson.sampler_normal(shape=(n,), dtype=float)

return (jvp, trace, J), (key, sample_fun)

Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/jacobian_squared.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def f(x):
J = func.jacfwd(f)(x0)
A = J @ J @ J @ J
trace = linalg.trace(A)
sample_fun = hutchinson.normal(shape=(n,), dtype=float)
sample_fun = hutchinson.sampler_normal(shape=(n,), dtype=float)

def Av(v):
return jvp(jvp(jvp(jvp(v))))
Expand Down
2 changes: 1 addition & 1 deletion docs/control_variates.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Imports:
>>> key = jax.random.PRNGKey(1)

>>> matvec = lambda x: a.T @ (a @ x)
>>> sample_fun = hutchinson.normal(shape=(2,))
>>> sample_fun = hutchinson.sampler_normal(shape=(2,))

```

Expand Down
6 changes: 3 additions & 3 deletions docs/higher_moments.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
>>> key = jax.random.PRNGKey(1)

>>> mvp = lambda x: a.T @ (a @ x)
>>> sample_fun = hutchinson.normal(shape=(2,))
>>> sample_fun = hutchinson.sampler_normal(shape=(2,))

```

Expand All @@ -21,7 +21,7 @@ Compute them as such

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
>>> normal = hutchinson.normal(shape=(6,))
>>> normal = hutchinson.sampler_normal(shape=(6,))
>>> mvp = lambda x: a.T @ (a @ x) + x
>>> first, second = hutchinson.trace_moments(mvp, key=key, sample_fun=normal)
>>> print(jnp.round(first, 1))
Expand Down Expand Up @@ -53,7 +53,7 @@ Implement this as follows:

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
>>> sample_fun = hutchinson.normal(shape=(6,))
>>> sample_fun = hutchinson.sampler_normal(shape=(6,))
>>> num_samples = 10_000
>>> mvp = lambda x: a.T @ (a @ x) + x
>>> first, second = hutchinson.trace_moments(
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Estimate the trace of the matrix:

```python
>>> key = jax.random.PRNGKey(1)
>>> normal = hutchinson.normal(shape=(2,))
>>> normal = hutchinson.sampler_normal(shape=(2,))
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
>>>
>>> print(jnp.round(trace))
Expand Down
6 changes: 3 additions & 3 deletions docs/log_determinants.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ Imports:
>>> key = jax.random.PRNGKey(1)

>>> matvec = lambda x: a.T @ (a @ x)
>>> sample_fun = hutchinson.normal(shape=(2,))
>>> sample_fun = hutchinson.sampler_normal(shape=(2,))

```


Estimate log-determinants as such:
```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
>>> sample_fun = hutchinson.normal(shape=(6,))
>>> sample_fun = hutchinson.sampler_normal(shape=(6,))
>>> matvec = lambda x: a.T @ (a @ x) + x
>>> order = 3
>>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun)
Expand All @@ -37,7 +37,7 @@ on arithmetic with $B$; no need to assemble $M$:

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 + jnp.eye(6)
>>> sample_fun = hutchinson.normal(shape=(6,))
>>> sample_fun = hutchinson.sampler_normal(shape=(6,))
>>> matvec = lambda x: (a @ x)
>>> vecmat = lambda x: (a.T @ x)
>>> order = 3
Expand Down
2 changes: 1 addition & 1 deletion docs/pytree_logdeterminants.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Now, we can compute the log-determinant with the flattened inputs as usual:
```python
>>> # Compute the log-determinant
>>> key = jax.random.PRNGKey(seed=1)
>>> sample_fun = hutchinson.normal(shape=f0_flat.shape)
>>> sample_fun = hutchinson.sampler_normal(shape=f0_flat.shape)
>>> order = 3
>>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun)

Expand Down
2 changes: 1 addition & 1 deletion docs/vector_calculus.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ For large-scale problems, it may be the only way of computing Laplacians reliabl
```python
>>> laplacian_dense = divergence_dense(gradient)
>>>
>>> normal = hutchinson.normal(shape=(3,))
>>> normal = hutchinson.sampler_normal(shape=(3,))
>>> key = jax.random.PRNGKey(1)
>>> laplacian_matfree = divergence_matfree(gradient, key=key, sample_fun=normal)
>>>
Expand Down
78 changes: 26 additions & 52 deletions matfree/hutchinson.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# trace_and_frobeniusnorm(): y=Ax; return (x@y, y@y)


def estimate(
def mc_estimate(
fun: Callable,
/,
*,
Expand All @@ -31,8 +31,8 @@ def estimate(
sample_fun:
Sampling function.
For trace-estimation, use
either [normal(...)][matfree.hutchinson.normal]
or [rademacher(...)][matfree.hutchinson.normal].
either [sampler_normal(...)][matfree.hutchinson.sampler_normal]
or [sampler_rademacher(...)][matfree.hutchinson.sampler_rademacher].
num_batches:
Number of batches when computing arithmetic means.
num_samples_per_batch:
Expand All @@ -50,7 +50,7 @@ def estimate(
[one of these functions](https://data-apis.org/array-api/2022.12/API_specification/statistical_functions.html)
would work.
"""
[result] = multiestimate(
[result] = mc_multiestimate(
fun,
key=key,
sample_fun=sample_fun,
Expand All @@ -62,7 +62,7 @@ def estimate(
return result


def multiestimate(
def mc_multiestimate(
fun: Callable,
/,
*,
Expand All @@ -76,7 +76,7 @@ def multiestimate(
"""Compute a Monte-Carlo estimate with multiple summary statistics.

The signature of this function is almost identical to
[estimate(...)][matfree.hutchinson.estimate].
[mc_estimate(...)][matfree.hutchinson.mc_estimate].
The only difference is that statistics_batch and statistics_combine are iterables
of summary statistics (of equal lengths).

Expand All @@ -85,15 +85,15 @@ def multiestimate(
Parameters
----------
fun:
Same as in [estimate(...)][matfree.hutchinson.estimate].
Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate].
key:
Same as in [estimate(...)][matfree.hutchinson.estimate].
Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate].
sample_fun:
Same as in [estimate(...)][matfree.hutchinson.estimate].
Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate].
num_batches:
Same as in [estimate(...)][matfree.hutchinson.estimate].
Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate].
num_samples_per_batch:
Same as in [estimate(...)][matfree.hutchinson.estimate].
Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate].
statistics_batch:
List or tuple of summary statistics to compute on batch-level.
statistics_combine:
Expand Down Expand Up @@ -145,7 +145,7 @@ def f_mean(key, /):
return f_mean


def normal(*, shape, dtype=float):
def sampler_normal(*, shape, dtype=float):
"""Construct a function that samples from a standard normal distribution."""

def fun(key):
Expand All @@ -154,7 +154,7 @@ def fun(key):
return fun


def rademacher(*, shape, dtype=float):
def sampler_rademacher(*, shape, dtype=float):
"""Construct a function that samples from a Rademacher distribution."""

def fun(key):
Expand All @@ -163,32 +163,6 @@ def fun(key):
return fun


class _VDCState(containers.NamedTuple):
    """Loop state threaded through the Van-der-Corput while-loop."""

    # Remaining part of the index; one base-`base` digit is consumed per step.
    n: int
    # Van-der-Corput value accumulated from the digits processed so far.
    vdc: float
    # Current digit weight denominator: base ** (number of digits processed).
    denom: int


def van_der_corput(n, /, base=2):
    """Compute the 'n'th element of the Van-der-Corput sequence.

    The digits of ``n`` in the given ``base`` are mirrored around the
    radix point, yielding a low-discrepancy value in [0, 1).
    """
    # No digits processed yet: zero accumulator, unit denominator.
    initial_state = _VDCState(n, vdc=0, denom=1)

    # Bind the base so the body function matches the while-loop signature.
    body_fun = func.partial(_van_der_corput_modify, base=base)
    final_state = control_flow.while_loop(_van_der_corput_cond, body_fun, initial_state)
    return final_state.vdc


def _van_der_corput_cond(state: _VDCState):
    """Continue looping while digits of the index remain to be consumed."""
    return 0 < state.n


def _van_der_corput_modify(state: _VDCState, *, base):
    """Consume one digit of the index and fold it into the accumulator."""
    # Next digit weight: one more base-`base` place after the radix point.
    next_denom = state.denom * base

    # Split off the least-significant digit of the remaining index.
    quotient, digit = divmod(state.n, base)

    # Mirror the digit to its fractional position and accumulate.
    next_vdc = state.vdc + digit / next_denom
    return _VDCState(quotient, next_vdc, next_denom)


def trace(Av: Callable, /, **kwargs) -> Array:
"""Estimate the trace of a matrix stochastically.

Expand All @@ -198,13 +172,13 @@ def trace(Av: Callable, /, **kwargs) -> Array:
Matrix-vector product function.
**kwargs:
Keyword-arguments to be passed to
[estimate()][matfree.hutchinson.estimate].
[mc_estimate()][matfree.hutchinson.mc_estimate].
"""

def quadform(vec):
return linalg.vecdot(vec, Av(vec))

return estimate(quadform, **kwargs)
return mc_estimate(quadform, **kwargs)


def trace_moments(Av: Callable, /, moments: Sequence[int] = (1, 2), **kwargs) -> Array:
Expand All @@ -219,7 +193,7 @@ def trace_moments(Av: Callable, /, moments: Sequence[int] = (1, 2), **kwargs) ->
the first and second moment.
**kwargs:
Keyword-arguments to be passed to
[multiestimate(...)][matfree.hutchinson.multiestimate].
[mc_multiestimate(...)][matfree.hutchinson.mc_multiestimate].
"""

def quadform(vec):
Expand All @@ -230,7 +204,7 @@ def moment(x, axis, *, power):

statistics_batch = [func.partial(moment, power=m) for m in moments]
statistics_combine = [np.mean] * len(moments)
return multiestimate(
return mc_multiestimate(
quadform,
statistics_batch=statistics_batch,
statistics_combine=statistics_combine,
Expand All @@ -255,15 +229,15 @@ def frobeniusnorm_squared(Av: Callable, /, **kwargs) -> Array:
Matrix-vector product function.
**kwargs:
Keyword-arguments to be passed to
[estimate()][matfree.hutchinson.estimate].
[mc_estimate()][matfree.hutchinson.mc_estimate].

"""

def quadform(vec):
x = Av(vec)
return linalg.vecdot(x, x)

return estimate(quadform, **kwargs)
return mc_estimate(quadform, **kwargs)


def diagonal_with_control_variate(Av: Callable, control: Array, /, **kwargs) -> Array:
Expand All @@ -278,7 +252,7 @@ def diagonal_with_control_variate(Av: Callable, control: Array, /, **kwargs) ->
This should be the best-possible estimate of the diagonal of the matrix.
**kwargs:
Keyword-arguments to be passed to
[estimate()][matfree.hutchinson.estimate].
[mc_estimate()][matfree.hutchinson.mc_estimate].

"""
return diagonal(lambda v: Av(v) - control * v, **kwargs) + control
Expand All @@ -293,14 +267,14 @@ def diagonal(Av: Callable, /, **kwargs) -> Array:
Matrix-vector product function.
**kwargs:
Keyword-arguments to be passed to
[estimate()][matfree.hutchinson.estimate].
[mc_estimate()][matfree.hutchinson.mc_estimate].

"""

def quadform(vec):
return vec * Av(vec)

return estimate(quadform, **kwargs)
return mc_estimate(quadform, **kwargs)


def trace_and_diagonal(Av: Callable, /, *, sample_fun: Callable, key: Array, **kwargs):
Expand All @@ -318,8 +292,8 @@ def trace_and_diagonal(Av: Callable, /, *, sample_fun: Callable, key: Array, **k
Matrix-vector product function.
sample_fun:
Sampling function.
Usually, either [normal][matfree.hutchinson.normal]
or [rademacher][matfree.hutchinson.normal].
Usually, either [sampler_normal][matfree.hutchinson.sampler_normal]
or [sampler_rademacher][matfree.hutchinson.sampler_rademacher].
key:
Pseudo-random number generator key.
**kwargs:
Expand Down Expand Up @@ -368,8 +342,8 @@ def diagonal_multilevel(
Pseudo-random number generator key.
sample_fun:
Sampling function.
Usually, either [normal][matfree.hutchinson.normal]
or [rademacher][matfree.hutchinson.normal].
Usually, either [sampler_normal][matfree.hutchinson.sampler_normal]
or [sampler_rademacher][matfree.hutchinson.sampler_rademacher].
num_levels:
Number of levels.
num_batches_per_level:
Expand Down
4 changes: 2 additions & 2 deletions matfree/slq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def logdet_spd(*args, **kwargs):
def trace_of_matfun_spd(matfun, order, Av, /, **kwargs):
"""Compute the trace of the function of a symmetric matrix."""
quadratic_form = _quadratic_form_slq_spd(matfun, order, Av)
return hutchinson.estimate(quadratic_form, **kwargs)
return hutchinson.mc_estimate(quadratic_form, **kwargs)


def _quadratic_form_slq_spd(matfun, order, Av, /):
Expand Down Expand Up @@ -72,7 +72,7 @@ def trace_of_matfun_product(matfun, order, *matvec_funs, matrix_shape, **kwargs)
quadratic_form = _quadratic_form_slq_product(
matfun, order, *matvec_funs, matrix_shape=matrix_shape
)
return hutchinson.estimate(quadratic_form, **kwargs)
return hutchinson.mc_estimate(quadratic_form, **kwargs)


def _quadratic_form_slq_product(matfun, depth, *matvec_funs, matrix_shape):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_hutchinson/test_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def fixture_key():
@testing.parametrize("num_batches", [1_000])
@testing.parametrize("num_samples_per_batch", [1_000])
@testing.parametrize("dim", [1, 10])
@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher])
@testing.parametrize(
"sample_fun", [hutchinson.sampler_normal, hutchinson.sampler_rademacher]
)
def test_diagonal(fun, key, num_batches, num_samples_per_batch, dim, sample_fun):
"""Assert that the estimated diagonal approximates the true diagonal accurately."""
# Linearise function
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hutchinson/test_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ def test_mean(key, num_batches, num_samples):
def fun(x):
return x**2

received = hutchinson.estimate(
received = hutchinson.mc_estimate(
fun,
num_batches=num_batches,
num_samples_per_batch=num_samples,
key=key,
sample_fun=hutchinson.normal(shape=()),
sample_fun=hutchinson.sampler_normal(shape=()),
)
assert np.allclose(received, 1.0, rtol=1e-1)
4 changes: 3 additions & 1 deletion tests/test_hutchinson/test_frobeniusnorm_squared.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def fixture_key():
@testing.parametrize("num_batches", [1_000])
@testing.parametrize("num_samples_per_batch", [1_000])
@testing.parametrize("dim", [1, 10])
@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher])
@testing.parametrize(
"sample_fun", [hutchinson.sampler_normal, hutchinson.sampler_rademacher]
)
def test_frobeniusnorm_squared(
fun, key, num_batches, num_samples_per_batch, dim, sample_fun
):
Expand Down
Loading
Loading