diff --git a/README.md b/README.md index 600b275..b1c4f48 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ Estimate the trace of the matrix: ```python >>> key = jax.random.PRNGKey(1) ->>> normal = hutchinson.normal(shape=(2,)) +>>> normal = hutchinson.sampler_normal(shape=(2,)) >>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal) >>> >>> print(jnp.round(trace)) diff --git a/docs/benchmarks/control_variates.py b/docs/benchmarks/control_variates.py index 0d3eeb6..5093292 100644 --- a/docs/benchmarks/control_variates.py +++ b/docs/benchmarks/control_variates.py @@ -22,7 +22,7 @@ def f(x): _, jvp = func.linearize(f, x0) J = func.jacfwd(f)(x0) trace = linalg.trace(J) - sample_fun = hutchinson.normal(shape=(n,), dtype=float) + sample_fun = hutchinson.sampler_normal(shape=(n,), dtype=float) return (jvp, trace, J), (key, sample_fun) diff --git a/docs/benchmarks/jacobian_squared.py b/docs/benchmarks/jacobian_squared.py index 5a253ea..082e05b 100644 --- a/docs/benchmarks/jacobian_squared.py +++ b/docs/benchmarks/jacobian_squared.py @@ -20,7 +20,7 @@ def f(x): J = func.jacfwd(f)(x0) A = J @ J @ J @ J trace = linalg.trace(A) - sample_fun = hutchinson.normal(shape=(n,), dtype=float) + sample_fun = hutchinson.sampler_normal(shape=(n,), dtype=float) def Av(v): return jvp(jvp(jvp(jvp(v)))) diff --git a/docs/control_variates.md b/docs/control_variates.md index 1f4230f..8507f2d 100644 --- a/docs/control_variates.md +++ b/docs/control_variates.md @@ -14,7 +14,7 @@ Imports: >>> key = jax.random.PRNGKey(1) >>> matvec = lambda x: a.T @ (a @ x) ->>> sample_fun = hutchinson.normal(shape=(2,)) +>>> sample_fun = hutchinson.sampler_normal(shape=(2,)) ``` diff --git a/docs/higher_moments.md b/docs/higher_moments.md index 9f74aa7..cea821f 100644 --- a/docs/higher_moments.md +++ b/docs/higher_moments.md @@ -9,7 +9,7 @@ >>> key = jax.random.PRNGKey(1) >>> mvp = lambda x: a.T @ (a @ x) ->>> sample_fun = hutchinson.normal(shape=(2,)) +>>> sample_fun = 
hutchinson.sampler_normal(shape=(2,)) ``` @@ -21,7 +21,7 @@ Compute them as such ```python >>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 ->>> normal = hutchinson.normal(shape=(6,)) +>>> normal = hutchinson.sampler_normal(shape=(6,)) >>> mvp = lambda x: a.T @ (a @ x) + x >>> first, second = hutchinson.trace_moments(mvp, key=key, sample_fun=normal) >>> print(jnp.round(first, 1)) @@ -53,7 +53,7 @@ Implement this as follows: ```python >>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 ->>> sample_fun = hutchinson.normal(shape=(6,)) +>>> sample_fun = hutchinson.sampler_normal(shape=(6,)) >>> num_samples = 10_000 >>> mvp = lambda x: a.T @ (a @ x) + x >>> first, second = hutchinson.trace_moments( diff --git a/docs/index.md b/docs/index.md index 600b275..b1c4f48 100644 --- a/docs/index.md +++ b/docs/index.md @@ -65,7 +65,7 @@ Estimate the trace of the matrix: ```python >>> key = jax.random.PRNGKey(1) ->>> normal = hutchinson.normal(shape=(2,)) +>>> normal = hutchinson.sampler_normal(shape=(2,)) >>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal) >>> >>> print(jnp.round(trace)) diff --git a/docs/log_determinants.md b/docs/log_determinants.md index dd6375c..8495536 100644 --- a/docs/log_determinants.md +++ b/docs/log_determinants.md @@ -12,7 +12,7 @@ Imports: >>> key = jax.random.PRNGKey(1) >>> matvec = lambda x: a.T @ (a @ x) ->>> sample_fun = hutchinson.normal(shape=(2,)) +>>> sample_fun = hutchinson.sampler_normal(shape=(2,)) ``` @@ -20,7 +20,7 @@ Imports: Estimate log-determinants as such: ```python >>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 ->>> sample_fun = hutchinson.normal(shape=(6,)) +>>> sample_fun = hutchinson.sampler_normal(shape=(6,)) >>> matvec = lambda x: a.T @ (a @ x) + x >>> order = 3 >>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun) @@ -37,7 +37,7 @@ on arithmetic with $B$; no need to assemble $M$: ```python >>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 + jnp.eye(6) ->>> sample_fun = 
hutchinson.normal(shape=(6,)) +>>> sample_fun = hutchinson.sampler_normal(shape=(6,)) >>> matvec = lambda x: (a @ x) >>> vecmat = lambda x: (a.T @ x) >>> order = 3 diff --git a/docs/pytree_logdeterminants.md b/docs/pytree_logdeterminants.md index 4f09041..d25bf88 100644 --- a/docs/pytree_logdeterminants.md +++ b/docs/pytree_logdeterminants.md @@ -70,7 +70,7 @@ Now, we can compute the log-determinant with the flattened inputs as usual: ```python >>> # Compute the log-determinant >>> key = jax.random.PRNGKey(seed=1) ->>> sample_fun = hutchinson.normal(shape=f0_flat.shape) +>>> sample_fun = hutchinson.sampler_normal(shape=f0_flat.shape) >>> order = 3 >>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun) diff --git a/docs/vector_calculus.md b/docs/vector_calculus.md index 558b634..69bbc20 100644 --- a/docs/vector_calculus.md +++ b/docs/vector_calculus.md @@ -85,7 +85,7 @@ For large-scale problems, it may be the only way of computing Laplacians reliabl ```python >>> laplacian_dense = divergence_dense(gradient) >>> ->>> normal = hutchinson.normal(shape=(3,)) +>>> normal = hutchinson.sampler_normal(shape=(3,)) >>> key = jax.random.PRNGKey(1) >>> laplacian_matfree = divergence_matfree(gradient, key=key, sample_fun=normal) >>> diff --git a/matfree/hutchinson.py b/matfree/hutchinson.py index 26f4e6a..88281ee 100644 --- a/matfree/hutchinson.py +++ b/matfree/hutchinson.py @@ -9,7 +9,7 @@ # trace_and_frobeniusnorm(): y=Ax; return (x@y, y@y) -def estimate( +def mc_estimate( fun: Callable, /, *, @@ -31,8 +31,8 @@ def estimate( sample_fun: Sampling function. For trace-estimation, use - either [normal(...)][matfree.hutchinson.normal] - or [rademacher(...)][matfree.hutchinson.normal]. + either [normal(...)][matfree.hutchinson.sampler_normal] + or [rademacher(...)][matfree.hutchinson.sampler_rademacher]. num_batches: Number of batches when computing arithmetic means. 
num_samples_per_batch: @@ -50,7 +50,7 @@ def estimate( [one of these functions](https://data-apis.org/array-api/2022.12/API_specification/statistical_functions.html) would work. """ - [result] = multiestimate( + [result] = mc_multiestimate( fun, key=key, sample_fun=sample_fun, @@ -62,7 +62,7 @@ def estimate( return result -def multiestimate( +def mc_multiestimate( fun: Callable, /, *, @@ -76,7 +76,7 @@ def multiestimate( """Compute a Monte-Carlo estimate with multiple summary statistics. The signature of this function is almost identical to - [estimate(...)][matfree.hutchinson.estimate]. + [mc_estimate(...)][matfree.hutchinson.mc_estimate]. The only difference is that statistics_batch and statistics_combine are iterables of summary statistics (of equal lengths). @@ -85,15 +85,15 @@ def multiestimate( Parameters ---------- fun: - Same as in [estimate(...)][matfree.hutchinson.estimate]. + Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate]. key: - Same as in [estimate(...)][matfree.hutchinson.estimate]. + Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate]. sample_fun: - Same as in [estimate(...)][matfree.hutchinson.estimate]. + Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate]. num_batches: - Same as in [estimate(...)][matfree.hutchinson.estimate]. + Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate]. num_samples_per_batch: - Same as in [estimate(...)][matfree.hutchinson.estimate]. + Same as in [mc_estimate(...)][matfree.hutchinson.mc_estimate]. statistics_batch: List or tuple of summary statistics to compute on batch-level. 
statistics_combine: @@ -145,7 +145,7 @@ def f_mean(key, /): return f_mean -def normal(*, shape, dtype=float): +def sampler_normal(*, shape, dtype=float): """Construct a function that samples from a standard normal distribution.""" def fun(key): @@ -154,7 +154,7 @@ def fun(key): return fun -def rademacher(*, shape, dtype=float): +def sampler_rademacher(*, shape, dtype=float): """Construct a function that samples from a Rademacher distribution.""" def fun(key): @@ -163,32 +163,6 @@ def fun(key): return fun -class _VDCState(containers.NamedTuple): - n: int - vdc: float - denom: int - - -def van_der_corput(n, /, base=2): - """Compute the 'n'th element of the Van-der-Corput sequence.""" - state = _VDCState(n, vdc=0, denom=1) - - vdc_modify = func.partial(_van_der_corput_modify, base=base) - state = control_flow.while_loop(_van_der_corput_cond, vdc_modify, state) - return state.vdc - - -def _van_der_corput_cond(state: _VDCState): - return state.n > 0 - - -def _van_der_corput_modify(state: _VDCState, *, base): - denom = state.denom * base - num, remainder = divmod(state.n, base) - vdc = state.vdc + remainder / denom - return _VDCState(num, vdc, denom) - - def trace(Av: Callable, /, **kwargs) -> Array: """Estimate the trace of a matrix stochastically. @@ -198,13 +172,13 @@ def trace(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [estimate()][matfree.hutchinson.estimate]. + [mc_estimate()][matfree.hutchinson.mc_estimate]. """ def quadform(vec): return linalg.vecdot(vec, Av(vec)) - return estimate(quadform, **kwargs) + return mc_estimate(quadform, **kwargs) def trace_moments(Av: Callable, /, moments: Sequence[int] = (1, 2), **kwargs) -> Array: @@ -219,7 +193,7 @@ def trace_moments(Av: Callable, /, moments: Sequence[int] = (1, 2), **kwargs) -> the first and second moment. **kwargs: Keyword-arguments to be passed to - [multiestimate(...)][matfree.hutchinson.multiestimate]. 
+ [mc_multiestimate(...)][matfree.hutchinson.mc_multiestimate]. """ def quadform(vec): @@ -230,7 +204,7 @@ def moment(x, axis, *, power): statistics_batch = [func.partial(moment, power=m) for m in moments] statistics_combine = [np.mean] * len(moments) - return multiestimate( + return mc_multiestimate( quadform, statistics_batch=statistics_batch, statistics_combine=statistics_combine, @@ -255,7 +229,7 @@ def frobeniusnorm_squared(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [estimate()][matfree.hutchinson.estimate]. + [mc_estimate()][matfree.hutchinson.mc_estimate]. """ @@ -263,7 +237,7 @@ def quadform(vec): x = Av(vec) return linalg.vecdot(x, x) - return estimate(quadform, **kwargs) + return mc_estimate(quadform, **kwargs) def diagonal_with_control_variate(Av: Callable, control: Array, /, **kwargs) -> Array: @@ -278,7 +252,7 @@ def diagonal_with_control_variate(Av: Callable, control: Array, /, **kwargs) -> This should be the best-possible estimate of the diagonal of the matrix. **kwargs: Keyword-arguments to be passed to - [estimate()][matfree.hutchinson.estimate]. + [mc_estimate()][matfree.hutchinson.mc_estimate]. """ return diagonal(lambda v: Av(v) - control * v, **kwargs) + control @@ -293,14 +267,14 @@ def diagonal(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [estimate()][matfree.hutchinson.estimate]. + [mc_estimate()][matfree.hutchinson.mc_estimate]. """ def quadform(vec): return vec * Av(vec) - return estimate(quadform, **kwargs) + return mc_estimate(quadform, **kwargs) def trace_and_diagonal(Av: Callable, /, *, sample_fun: Callable, key: Array, **kwargs): @@ -318,8 +292,8 @@ def trace_and_diagonal(Av: Callable, /, *, sample_fun: Callable, key: Array, **k Matrix-vector product function. sample_fun: Sampling function. - Usually, either [normal][matfree.hutchinson.normal] - or [rademacher][matfree.hutchinson.normal]. 
+ Usually, either [normal][matfree.hutchinson.sampler_normal] + or [rademacher][matfree.hutchinson.sampler_rademacher]. key: Pseudo-random number generator key. **kwargs: @@ -368,8 +342,8 @@ def diagonal_multilevel( Pseudo-random number generator key. sample_fun: Sampling function. - Usually, either [normal][matfree.hutchinson.normal] - or [rademacher][matfree.hutchinson.normal]. + Usually, either [normal][matfree.hutchinson.sampler_normal] + or [rademacher][matfree.hutchinson.sampler_rademacher]. num_levels: Number of levels. num_batches_per_level: diff --git a/matfree/slq.py b/matfree/slq.py index 35f0a0a..5575ac7 100644 --- a/matfree/slq.py +++ b/matfree/slq.py @@ -12,7 +12,7 @@ def logdet_spd(*args, **kwargs): def trace_of_matfun_spd(matfun, order, Av, /, **kwargs): """Compute the trace of the function of a symmetric matrix.""" quadratic_form = _quadratic_form_slq_spd(matfun, order, Av) - return hutchinson.estimate(quadratic_form, **kwargs) + return hutchinson.mc_estimate(quadratic_form, **kwargs) def _quadratic_form_slq_spd(matfun, order, Av, /): @@ -72,7 +72,7 @@ def trace_of_matfun_product(matfun, order, *matvec_funs, matrix_shape, **kwargs) quadratic_form = _quadratic_form_slq_product( matfun, order, *matvec_funs, matrix_shape=matrix_shape ) - return hutchinson.estimate(quadratic_form, **kwargs) + return hutchinson.mc_estimate(quadratic_form, **kwargs) def _quadratic_form_slq_product(matfun, depth, *matvec_funs, matrix_shape): diff --git a/tests/test_hutchinson/test_diagonal.py b/tests/test_hutchinson/test_diagonal.py index 5b39b04..e3729b3 100644 --- a/tests/test_hutchinson/test_diagonal.py +++ b/tests/test_hutchinson/test_diagonal.py @@ -23,7 +23,9 @@ def fixture_key(): @testing.parametrize("num_batches", [1_000]) @testing.parametrize("num_samples_per_batch", [1_000]) @testing.parametrize("dim", [1, 10]) -@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher]) +@testing.parametrize( + "sample_fun", [hutchinson.sampler_normal, 
hutchinson.sampler_rademacher] +) def test_diagonal(fun, key, num_batches, num_samples_per_batch, dim, sample_fun): """Assert that the estimated diagonal approximates the true diagonal accurately.""" # Linearise function diff --git a/tests/test_hutchinson/test_estimate.py b/tests/test_hutchinson/test_estimate.py index fe5c915..8bab963 100644 --- a/tests/test_hutchinson/test_estimate.py +++ b/tests/test_hutchinson/test_estimate.py @@ -12,11 +12,11 @@ def test_mean(key, num_batches, num_samples): def fun(x): return x**2 - received = hutchinson.estimate( + received = hutchinson.mc_estimate( fun, num_batches=num_batches, num_samples_per_batch=num_samples, key=key, - sample_fun=hutchinson.normal(shape=()), + sample_fun=hutchinson.sampler_normal(shape=()), ) assert np.allclose(received, 1.0, rtol=1e-1) diff --git a/tests/test_hutchinson/test_frobeniusnorm_squared.py b/tests/test_hutchinson/test_frobeniusnorm_squared.py index 626abe6..498b3ac 100644 --- a/tests/test_hutchinson/test_frobeniusnorm_squared.py +++ b/tests/test_hutchinson/test_frobeniusnorm_squared.py @@ -23,7 +23,9 @@ def fixture_key(): @testing.parametrize("num_batches", [1_000]) @testing.parametrize("num_samples_per_batch", [1_000]) @testing.parametrize("dim", [1, 10]) -@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher]) +@testing.parametrize( + "sample_fun", [hutchinson.sampler_normal, hutchinson.sampler_rademacher] +) def test_frobeniusnorm_squared( fun, key, num_batches, num_samples_per_batch, dim, sample_fun ): diff --git a/tests/test_hutchinson/test_multiestimate.py b/tests/test_hutchinson/test_multiestimate.py index b04609f..17668af 100644 --- a/tests/test_hutchinson/test_multiestimate.py +++ b/tests/test_hutchinson/test_multiestimate.py @@ -11,12 +11,12 @@ def test_mean_and_max(key, num_batches, num_samples): def fun(x): return x**2 - mean, amax = hutchinson.multiestimate( + mean, amax = hutchinson.mc_multiestimate( fun, num_batches=num_batches, 
num_samples_per_batch=num_samples, key=key, - sample_fun=hutchinson.normal(shape=()), + sample_fun=hutchinson.sampler_normal(shape=()), statistics_batch=[np.mean, np.array_max], statistics_combine=[np.mean, np.array_max], ) diff --git a/tests/test_hutchinson/test_trace.py b/tests/test_hutchinson/test_trace.py index 867cbc7..5ff4752 100644 --- a/tests/test_hutchinson/test_trace.py +++ b/tests/test_hutchinson/test_trace.py @@ -23,7 +23,9 @@ def fixture_key(): @testing.parametrize("num_batches", [1_000]) @testing.parametrize("num_samples_per_batch", [1_000]) @testing.parametrize("dim", [1, 10]) -@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher]) +@testing.parametrize( + "sample_fun", [hutchinson.sampler_normal, hutchinson.sampler_rademacher] +) def test_trace(fun, key, num_batches, num_samples_per_batch, dim, sample_fun): """Assert that the estimated trace approximates the true trace accurately.""" # Linearise function diff --git a/tests/test_hutchinson/test_trace_and_diagonal.py b/tests/test_hutchinson/test_trace_and_diagonal.py index f8f0391..30f2d06 100644 --- a/tests/test_hutchinson/test_trace_and_diagonal.py +++ b/tests/test_hutchinson/test_trace_and_diagonal.py @@ -22,7 +22,9 @@ def fixture_key(): @testing.parametrize("num_samples", [10_000]) @testing.parametrize("dim", [5]) -@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher]) +@testing.parametrize( + "sample_fun", [hutchinson.sampler_normal, hutchinson.sampler_rademacher] +) def test_trace_and_diagonal(fun, key, num_samples, dim, sample_fun): """Assert that the estimated trace and diagonal approximations are accurate.""" # Linearise function diff --git a/tests/test_hutchinson/test_trace_moments.py b/tests/test_hutchinson/test_trace_moments.py index 90ddbf3..21d195e 100644 --- a/tests/test_hutchinson/test_trace_moments.py +++ b/tests/test_hutchinson/test_trace_moments.py @@ -32,7 +32,7 @@ def test_variance_normal(J_and_jvp, key, num_batches, 
num_samples_per_batch, dim """Assert that the estimated trace approximates the true trace accurately.""" # Estimate the trace J, jvp = J_and_jvp - fun = hutchinson.normal(shape=(dim,), dtype=float) + fun = hutchinson.sampler_normal(shape=(dim,), dtype=float) first, second = hutchinson.trace_moments( jvp, key=key, @@ -58,7 +58,7 @@ def test_variance_rademacher(J_and_jvp, key, num_batches, num_samples_per_batch, """Assert that the estimated trace approximates the true trace accurately.""" # Estimate the trace J, jvp = J_and_jvp - fun = hutchinson.rademacher(shape=(dim,), dtype=float) + fun = hutchinson.sampler_rademacher(shape=(dim,), dtype=float) first, second = hutchinson.trace_moments( jvp, key=key, diff --git a/tests/test_hutchinson/test_van_der_corput.py b/tests/test_hutchinson/test_van_der_corput.py deleted file mode 100644 index 8890d3c..0000000 --- a/tests/test_hutchinson/test_van_der_corput.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Tests for Monte-Carlo machinery.""" - -from matfree import hutchinson -from matfree.backend import np - - -def test_van_der_corput(): - """Assert that the van-der-Corput sequence yields values as expected.""" - expected = np.asarray([0, 0.5, 0.25, 0.75, 0.125, 0.625, 0.375, 0.875, 0.0625]) - received = np.asarray([hutchinson.van_der_corput(i) for i in range(9)]) - assert np.allclose(received, expected) - - expected = np.asarray([0.0, 1 / 3, 2 / 3, 1 / 9, 4 / 9, 7 / 9, 2 / 9, 5 / 9, 8 / 9]) - received = np.asarray([hutchinson.van_der_corput(i, base=3) for i in range(9)]) - assert np.allclose(received, expected) diff --git a/tests/test_slq/test_logdet_product.py b/tests/test_slq/test_logdet_product.py index f1c3df2..76b4386 100644 --- a/tests/test_slq/test_logdet_product.py +++ b/tests/test_slq/test_logdet_product.py @@ -22,7 +22,7 @@ def test_logdet_product(A, order): """Assert that logdet_product yields an accurate estimate.""" _, ncols = np.shape(A) key = prng.prng_key(3) - fun = hutchinson.normal(shape=(ncols,)) + fun = 
hutchinson.sampler_normal(shape=(ncols,)) received = slq.logdet_product( order, lambda v: A @ v, diff --git a/tests/test_slq/test_logdet_spd.py b/tests/test_slq/test_logdet_spd.py index cfcfb60..4c8b0b2 100644 --- a/tests/test_slq/test_logdet_spd.py +++ b/tests/test_slq/test_logdet_spd.py @@ -23,7 +23,7 @@ def test_logdet_spd(A, order): """Assert that the log-determinant estimation matches the true log-determinant.""" n, _ = np.shape(A) key = prng.prng_key(1) - fun = hutchinson.normal(shape=(n,)) + fun = hutchinson.sampler_normal(shape=(n,)) received = slq.logdet_spd( order, lambda v: A @ v, diff --git a/tests/test_slq/test_logdet_spd_autodiff.py b/tests/test_slq/test_logdet_spd_autodiff.py index ca4b063..b8f0b78 100644 --- a/tests/test_slq/test_logdet_spd_autodiff.py +++ b/tests/test_slq/test_logdet_spd_autodiff.py @@ -32,7 +32,7 @@ def fun(s): def _logdet(A, order, key): n, _ = np.shape(A) - fun = hutchinson.normal(shape=(n,)) + fun = hutchinson.sampler_normal(shape=(n,)) return slq.logdet_spd( order, lambda v: A @ v, diff --git a/tests/test_slq/test_schatten_norm.py b/tests/test_slq/test_schatten_norm.py index f8e3428..f9a208d 100644 --- a/tests/test_slq/test_schatten_norm.py +++ b/tests/test_slq/test_schatten_norm.py @@ -26,7 +26,7 @@ def test_schatten_norm(A, order, power): _, ncols = np.shape(A) key = prng.prng_key(1) - fun = hutchinson.normal(shape=(ncols,)) + fun = hutchinson.sampler_normal(shape=(ncols,)) received = slq.schatten_norm( order, lambda v: A @ v,