From 835a1d4b5b58618aec277a6257a2d0acbbea909f Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Wed, 4 Sep 2024 13:50:10 +0200 Subject: [PATCH 01/25] Preliminary function to compute the spectrum --- exponax/_spectral.py | 45 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 042362e..3106be6 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -1,10 +1,13 @@ from itertools import product from typing import Optional, TypeVar, Union +import jax import jax.numpy as jnp from jaxtyping import Array, Bool, Complex, Float +C = TypeVar("C") D = TypeVar("D") +N = TypeVar("N") def build_wavenumbers( @@ -652,3 +655,45 @@ def make_incompressible( ) return incompressible_field + + +def get_power_spectrum( + field: Float[Array, "C ... N"], +) -> Float[Array, "C (N//2)+1"]: + """ + Preliminary function -> might not be working correctly ... :/ + + Inspired by: + https://github.com/scaomath/torch-cfd/blob/8c64319272f7660a57c491d823384130823900fe/sfno/visualizations.py#L114 + + """ + num_spatial_dims = field.ndim - 1 + num_points = field.shape[-1] + + field_hat = fft(field, num_spatial_dims=num_spatial_dims) + + wavenumbers_mesh = build_wavenumbers(num_spatial_dims, num_points) + wavenumbers_1d = build_wavenumbers(1, num_points) + wavenumbers_norm = jnp.linalg.norm(wavenumbers_mesh, axis=0, keepdims=True) + + dk = wavenumbers_1d[0, 1] - wavenumbers_1d[0, 0] + + power = jnp.abs(field_hat) ** 2 * 1 / 2 + + spectrum = [] + + def power_in_bucket(p, k): + lower_limit = k - dk / 2 + upper_limit = k + dk / 2 + mask = (wavenumbers_norm[0] >= lower_limit) & ( + wavenumbers_norm[0] < upper_limit + ) + return jnp.sum(p[mask]) + + for k in wavenumbers_1d[0, :]: + spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(power, k)) + + spectrum = jnp.stack(spectrum, axis=-1) + spectrum /= jnp.sum(spectrum, axis=-1, keepdims=True) + + return spectrum From 43f2c35f7c42d5ebbdcb5d0b59ed6279dffd1b68 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Wed, 4 Sep 2024 14:03:15 +0200 Subject: [PATCH 02/25] Potential improvement to spectrum computation --- exponax/_spectral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 3106be6..61ea0c0 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -688,12 +688,12 @@ def power_in_bucket(p, k): mask = (wavenumbers_norm[0] >= lower_limit) & ( wavenumbers_norm[0] < upper_limit ) - return jnp.sum(p[mask]) + return jnp.mean(p[mask]) for k in wavenumbers_1d[0, :]: spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(power, k)) spectrum = jnp.stack(spectrum, axis=-1) - spectrum /= jnp.sum(spectrum, axis=-1, keepdims=True) + # spectrum /= jnp.sum(spectrum, axis=-1, keepdims=True) return spectrum From ea7238b70c1d62f69688799add3b52de5031c259 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Wed, 4 Sep 2024 14:12:07 +0200 Subject: [PATCH 03/25] Add hint on enstrophy --- exponax/_spectral.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 61ea0c0..84df3ef 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -663,6 +663,10 @@ def get_power_spectrum( """ Preliminary function -> might not be working correctly ... :/ + If passed a vorticity field like produced by + `exponax.stepper.NavierStokesVorticity`, this will produce the enstrophy + spectrum. + Inspired by: https://github.com/scaomath/torch-cfd/blob/8c64319272f7660a57c491d823384130823900fe/sfno/visualizations.py#L114 From d48f50432e49d6269be9e003cb04b0c8a9ffd0dc Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 07:41:00 +0200 Subject: [PATCH 04/25] Fix on wavenumber norm computation --- exponax/_spectral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 84df3ef..1773be2 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -256,7 +256,7 @@ def low_pass_filter_mask( for wn_grid in wavenumbers: mask = mask & (jnp.abs(wn_grid) <= cutoff) else: - mask = jnp.linalg.norm(mask, axis=0) <= cutoff + mask = jnp.linalg.norm(wavenumbers, axis=0) <= cutoff mask = mask[jnp.newaxis, ...] From c50469598e8fe22854334a1e5cbab0fb0baae9a5 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 08:21:43 +0200 Subject: [PATCH 05/25] Rename to better reflect its effect on oddball mode --- exponax/_interpolation.py | 6 +++--- exponax/_spectral.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/exponax/_interpolation.py b/exponax/_interpolation.py index 412a46a..cdf612f 100644 --- a/exponax/_interpolation.py +++ b/exponax/_interpolation.py @@ -15,7 +15,7 @@ fft, get_modes_slices, ifft, - nyquist_filter_mask, + oddball_filter_mask, space_indices, wavenumber_shape, ) @@ -247,7 +247,7 @@ def map_between_resolutions( if new_num_points > old_num_points: # Upscaling if old_num_points % 2 == 0 and oddball_zero: - old_state_hat_scaled *= nyquist_filter_mask( + old_state_hat_scaled *= oddball_filter_mask( num_spatial_dims, old_num_points ) @@ -273,7 +273,7 @@ def map_between_resolutions( if old_num_points > new_num_points: # Downscaling if new_num_points % 2 == 0 and oddball_zero: - new_state_hat *= nyquist_filter_mask(num_spatial_dims, new_num_points) + new_state_hat *= oddball_filter_mask(num_spatial_dims, new_num_points) new_state = ifft( new_state_hat, diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 1773be2..c53567b 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -263,20 +263,25 @@ def low_pass_filter_mask( return mask -def nyquist_filter_mask( +def oddball_filter_mask( num_spatial_dims: int, num_points: int, ) -> Bool[Array, "1 ... N"]: """ - Creates mask that if multiplied with a field in Fourier space will remove - the Nyquist mode. + Creates mask that if multiplied with a field in Fourier space remove the + Nyquist mode if the number of degrees of freedom is even. **Arguments:** - `num_spatial_dims`: The number of spatial dimensions. - `num_points`: The number of points in each spatial dimension. **Returns:** - - `mask`: The Nyquist filter mask, shape `(1, ..., N//2+1)`. + - `mask`: The oddball filter mask, shape `(1, ..., N//2+1)`. + + !!! info + For more background on why this is needed, see + https://www.mech.kth.se/~mattias/simson-user-guide-v4.0.pdf section + 6.2.4 and https://math.mit.edu/~stevenj/fft-deriv.pdf """ if num_points % 2 == 1: # Odd number of degrees of freedom (no issue with the Nyquist mode) From 82476f5c2f2c63c22e7b9d6b81dca901d35881f4 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 08:22:25 +0200 Subject: [PATCH 06/25] Add tests on filter mask --- tests/test_filter_masks.py | 451 +++++++++++++++++++++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 tests/test_filter_masks.py diff --git a/tests/test_filter_masks.py b/tests/test_filter_masks.py new file mode 100644 index 0000000..2818566 --- /dev/null +++ b/tests/test_filter_masks.py @@ -0,0 +1,451 @@ +import numpy as np + +import exponax as ex + + +def test_low_pass_filter_masks_1d(): + # Need to test both for even and odd number of points because that changes + # how the Nyquist mode is treated. + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(1, 10, cutoff=3), + np.array([[True, True, True, True, False, False]]), + ) + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(1, 11, cutoff=3), + np.array([[True, True, True, True, False, False]]), + ) + + +def test_nyquist_filter_masks_1d(): + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(1, 10), + np.array([[True, True, True, True, True, False]]), + ) + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(1, 11), + np.array([[True, True, True, True, True, True]]), + ) + + +def test_low_pass_filter_masks_2d(): + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(2, 10, cutoff=3), + np.array( + [ + [ + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + ] + ] + ), + ) + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(2, 11, cutoff=3), + np.array( + [ + [ + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + ] + ] + ), + ) + # Below is with `axis_separate=False` which not creates `True`-hypercube + # regions but spheres (in 3d) or circles (in 2d) of `True` values. + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(2, 10, cutoff=3, axis_separate=False), + np.array( + [ + [ + [True, True, True, True, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + ] + ] + ), + ) + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(2, 11, cutoff=3, axis_separate=False), + np.array( + [ + [ + [True, True, True, True, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + ] + ] + ), + ) + + +def test_nyquist_filter_masks_2d(): + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(2, 10), + np.array( + [ + [ + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [False, False, False, False, False, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + ] + ] + ), + ) + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(2, 11), + np.array( + [ + [ + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + ] + ] + ), + ) + + +def test_low_pass_filter_masks_3d(): + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(3, 8, cutoff=2), + np.array( + [ + [ + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + ] + ] + ), + ) + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(3, 9, cutoff=2), + np.array( + [ + [ + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + ] + ] + ), + ) + + # TODO: Add tests for `axis_separate=False` in 3D. + + +def test_nyquist_filter_masks_3d(): + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(3, 8), + np.array( + [ + [ + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + ] + ] + ), + ) + + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(3, 9), + np.ones((1, 9, 9, (9 // 2) + 1), dtype=bool), + ) From b6f7ecd1e5e9eea3128f19e49815b5d01900d1ad Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 08:40:17 +0200 Subject: [PATCH 07/25] Improve first half of spectral documentation --- exponax/_spectral.py | 291 +++++++++++++++++++++++++++++-------------- 1 file changed, 199 insertions(+), 92 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index c53567b..fdb876d 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -22,14 +22,16 @@ def build_wavenumbers( `jax.numpy.fft.rfftn`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `wavenumbers`: An array of wavenumber integer coordinates, shape - `(D, ..., (N//2)+1)`. + + - `wavenumbers`: An array of wavenumber integer coordinates, shape + `(D, ..., (N//2)+1)`. """ right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) @@ -55,20 +57,26 @@ def build_scaled_wavenumbers( indexing: str = "ij", ) -> Float[Array, "D ... (N//2)+1"]: """ - Setup an array containing scaled wavenumbers associated with a - "num_spatial_dims"-dimensional rfft (real-valued FFT) - `jax.numpy.fft.rfftn`. Scaling is done by `2 * pi / L`. + Setup an array containing **scaled** wavenumbers associated with a + "num_spatial_dims"-dimensional rfft (real-valued FFT) `jax.numpy.fft.rfftn`. + Scaling is done by `2 * pi / L`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `domain_extent`: The domain extent. - - `num_points`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The domain extent. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `wavenumbers`: An array of wavenumber integer coordinates, shape - `(D, ..., (N//2)+1)`. + + - `wavenumbers`: An array of wavenumber integer coordinates, shape + `(D, ..., (N//2)+1)`. + + !!! info + These correctly scaled wavenumbers are used to set up derivative + operators via `1j * wavenumbers`. """ scale = 2 * jnp.pi / domain_extent wavenumbers = build_wavenumbers(num_spatial_dims, num_points, indexing=indexing) @@ -86,15 +94,21 @@ def build_derivative_operator( Setup the derivative operator in Fourier space. **Arguments:** - - `D`: The number of spatial dimensions. - - `L`: The domain extent. - - `N`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** the + right boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom is + `Nᵈ`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. **Returns:** - - `derivative_operator`: The derivative operator, shape `(D, ..., - N//2+1)`. + + - `derivative_operator`: The derivative operator in Fourier space + (complex-valued array) """ return 1j * build_scaled_wavenumbers( num_spatial_dims, domain_extent, num_points, indexing=indexing @@ -107,16 +121,27 @@ def build_laplace_operator( order: int = 2, ) -> Complex[Array, "1 ... (N//2)+1"]: """ - Given the derivative operator of [`build_derivative_operator`], return the - Laplace operator. + Given the derivative operator of + [`exponax.spectral.build_derivative_operator`], return the Laplace operator. + + In state space: + + Δ = ∇ ⋅ ∇ + + And in Fourier space: + + i² k⃗ᵀ k⃗ = - k⃗ᵀ k⃗ **Arguments:** - - `derivative_operator`: The derivative operator, shape `(D, ..., - N//2+1)`. - - `order`: The order of the Laplace operator. Default is `2`. + + - `derivative_operator`: The derivative operator in Fourier space. + - `order`: The order of the Laplace operator. Default is `2`. Use a higher + even number for "higher-order Laplacians". For example, `order=4` will + return the biharmonic operator (without spatial mixing). **Returns:** - - `laplace_operator`: The Laplace operator, shape `(1, ..., N//2+1)`. + + - `laplace_operator`: The Laplace operator in Fourier space. """ if order % 2 != 0: raise ValueError("Order must be even.") @@ -131,18 +156,31 @@ def build_gradient_inner_product_operator( order: int = 1, ) -> Complex[Array, "1 ... (N//2)+1"]: """ - Given the derivative operator of [`build_derivative_operator`] and a velocity - field, return the operator that computes the inner product of the gradient - with the velocity. + Given the derivative operator of [`build_derivative_operator`] and a + velocity vector, return the operator that computes the inner product of the + gradient with the velocity. + + In state space this is: + + c⃗ ⋅ ∇ + + And in Fourier space: + + c⃗ ⋅ i k⃗ **Arguments:** - - `derivative_operator`: The derivative operator, shape `(D, ..., - N//2+1)`. - - `velocity`: The velocity field, shape `(D,)`. - - `order`: The order of the gradient. Default is `1`. + + - `derivative_operator`: The derivative operator in Fourier space. + - `velocity`: The velocity vector, must be an array with one axis with as + many dimensions as the derivative operator has in its leading axis. + - `order`: The order of the gradient. Default is `1` which is the "regular + gradient". Use higher orders for higher-order gradients given in terms + elementwise products. For example, `order=3` will return `c⃗ ⋅ (∇ ⊙ ∇ ⊙ + ∇)` **Returns:** - - `operator`: The operator, shape `(1, ..., N//2+1)`. + + - `operator`: The operator in Fourier space. """ if order % 2 != 1: raise ValueError("Order must be odd.") @@ -161,52 +199,47 @@ def build_gradient_inner_product_operator( # Need to add singleton channel axis operator = operator[None, ...] - # Old form below - # # Need to move the channel/dimension axis last to enable autobroadcast over - # # the arbitrary number of spatial axes, Then we can move this singleton axis - # # back to the front - # operator = jnp.swapaxes( - # jnp.sum( - # velocity - # * jnp.swapaxes( - # derivative_operator**order, - # 0, - # -1, - # ), - # axis=-1, - # keepdims=True, - # ), - # 0, - # -1, - # ) - return operator def space_indices(num_spatial_dims: int) -> tuple[int, ...]: """ - Returns the indices within a field array that correspond to the spatial - dimensions. + Returns the axes indices within a state array that correspond to the spatial + axes. + + !!! example + For a 2D field array, the spatial indices are `(-2, -1)`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. + + - `num_spatial_dims`: The number of spatial dimensions. **Returns:** - - `indices`: The indices of the spatial dimensions. + + - `indices`: The indices of the spatial axes. """ return tuple(range(-num_spatial_dims, 0)) def spatial_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: """ - Returns the shape of a spatial field array. + Returns the shape of a spatial field array (without its leading channel + axis). This follows the `Exponax` convention that the resolution is + indentical in each dimension. + + !!! example + For a 2D field array with 64 points in each dimension, the spatial shape + is `(64, 64)`. For a 3D field array with 32 points in each dimension, + the spatial shape is `(32, 32, 32)`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. **Returns:** - - `shape`: The shape of the spatial field array. + + - `shape`: The shape of the spatial field array. """ return (num_points,) * num_spatial_dims @@ -214,14 +247,23 @@ def spatial_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: def wavenumber_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: """ Returns the spatial shape of a field in Fourier space (assuming the usage of - rfft, `jax.numpy.fft.rfftn`). + `exponax.fft` which internall performs a real-valued fft + `jax.numpy.fft.rfftn`). + + !!! example + For a 2D field array with 64 points in each dimension, the wavenumber shape + is `(64, 33)`. For a 3D field array with 32 points in each dimension, + the spatial shape is `(32, 32, 17)`. For a 1D field array with 51 points, + the wavenumber shape is `(26,)`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. **Returns:** - - `shape`: The shape of the spatial field array. + + - `shape`: The shape of the spatial axes of a state array in Fourier space. """ return (num_points,) * (num_spatial_dims - 1) + (num_points // 2 + 1,) @@ -238,16 +280,46 @@ def low_pass_filter_mask( Create a low-pass filter mask in Fourier space. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. - - `cutoff`: The cutoff wavenumber. This is inclusive. - - `axis_separate`: Whether to apply the cutoff to each axis separately. - Default is `True`. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `cutoff`: The cutoff wavenumber. This is inclusive. + - `axis_separate`: Whether to apply the cutoff to each axis separately. + If `True` (default) the low-pass chunk is a hypercube in Fourier space. + If `False`, the low-pass chunk is a sphere in Fourier space. Only + relevant for `num_spatial_dims` >= 2. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `mask`: The low-pass filter mask, shape `(1, ..., N//2+1)`. + + - `mask`: The low-pass filter mask. + + + !!! example + In 1D with 10 points, a cutoff of 3 will produce the mask + + ```python + + array([[ True, True, True, True, False, False]]) + + ``` + + To better understand this, let's produce the corresponding wavenumbers: + + ```python + + wn = exponax.spectral.build_wavenumbers(1, 10) + + print(wn) + + # array([[0, 1, 2, 3, 4, 5]]) + + ``` + + There are 6 wavenumbers in total (because this equals `(N//2)+1`), the + zeroth wavenumber is the mean mode, and then the mask includes the next + three wavenumbers because its **`cutoff` is inclusive**. """ wavenumbers = build_wavenumbers(num_spatial_dims, num_points, indexing=indexing) @@ -272,11 +344,29 @@ def oddball_filter_mask( Nyquist mode if the number of degrees of freedom is even. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. **Returns:** - - `mask`: The oddball filter mask, shape `(1, ..., N//2+1)`. + + - `mask`: The oddball filter mask which is `True` for all wavenumbers except + the Nyquist mode if the number of degrees of freedom is even. + + !!! example + ```python + + mask_even = exponax.spectral.oddball_filter_mask(1, 6) + + # array([[ True, True, True, False]]) + + mask_odd = exponax.spectral.oddball_filter_mask(1, 7) + + # array([[ True, True, True, True]]) + + ``` + + For higher-dimensional examples, see `tests/test_filter_masks.py`. !!! info For more background on why this is needed, see @@ -297,17 +387,11 @@ def oddball_filter_mask( return low_pass_filter_mask( num_spatial_dims, num_points, + # The cutoff is **inclusive** cutoff=mode_below_nyquist - 1, axis_separate=True, ) - # # Todo: Do we need the below? - # wavenumbers = build_wavenumbers(D, N, scaled=False) - # mask = True - # for wn_grid in wavenumbers: - # mask = mask & (wn_grid != -mode_below_nyquist) - # return mask - def build_scaling_array( num_spatial_dims: int, @@ -317,16 +401,23 @@ def build_scaling_array( ) -> Float[Array, "1 ... (N//2)+1"]: """ Creates an array of the values that would be seen in the result of a - (real-valued) Fourier transform of a signal of amplitude 1. + (real-valued) Fourier transform of signal which has amplitude 1 in all + resolvable wavenumbers. + + This can be used to counteract the scaling that is applied by the FFT + assuming one uses the default `norm="backward"`, which is also the default + done by the `Exponax` wrapped routines `exponax.fft` and `exponax.ifft`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `scaling`: The scaling array, shape `(1, ..., N//2+1)`. + + - `scaling`: The scaling array. """ right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) @@ -383,6 +474,22 @@ def build_reconstructional_scaling_array( """ Similar to `build_scaling_array`, but corresponds to the scaling observed when reconstructing a signal from its Fourier transform. + + This is different because it accounts for the fact `Exponax` uses the `rfft` + which only contributes half the coefficient magnitude for the axis which is + (approximately) halved in the Fourier space. A difference to `build_scaling_array` + can only be observed if `num_spatial_dims >= 2`. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + + - `scaling`: The reconstructional scaling array. """ right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) From 22f19bdbbae1297dad3f65b354a4b12425e69400 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 08:49:24 +0200 Subject: [PATCH 08/25] Merge common subexpressions --- exponax/_spectral.py | 106 +++++++++++++++++-------------------------- 1 file changed, 42 insertions(+), 64 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index fdb876d..5044af2 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -1,5 +1,5 @@ from itertools import product -from typing import Optional, TypeVar, Union +from typing import Literal, Optional, TypeVar, Union import jax import jax.numpy as jnp @@ -393,31 +393,19 @@ def oddball_filter_mask( ) -def build_scaling_array( +def _build_scaling_array( num_spatial_dims: int, num_points: int, *, + others_fraction: Literal[2, 1], indexing: str = "ij", -) -> Float[Array, "1 ... (N//2)+1"]: +): """ - Creates an array of the values that would be seen in the result of a - (real-valued) Fourier transform of signal which has amplitude 1 in all - resolvable wavenumbers. - - This can be used to counteract the scaling that is applied by the FFT - assuming one uses the default `norm="backward"`, which is also the default - done by the `Exponax` wrapped routines `exponax.fft` and `exponax.ifft`. - - **Arguments:** + Shared routine between `build_scaling_array` and + `build_reconstructional_scaling_array`. - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. - - **Returns:** - - - `scaling`: The scaling array. + The `others_fraction` argument is used to determine the scaling of the + wavenumbers in the spatial dimensions that are **not** the right-most one. """ right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) @@ -430,7 +418,7 @@ def build_scaling_array( other_scaling = jnp.where( other_wavenumbers == 0, num_points, - num_points / 2, + num_points / others_fraction, # Only difference ) # If N is even, special treatment for the Nyquist mode @@ -465,6 +453,37 @@ def build_scaling_array( return scaling +def build_scaling_array( + num_spatial_dims: int, + num_points: int, + *, + indexing: str = "ij", +) -> Float[Array, "1 ... (N//2)+1"]: + """ + Creates an array of the values that would be seen in the result of a + (real-valued) Fourier transform of signal which has amplitude 1 in all + resolvable wavenumbers. + + This can be used to counteract the scaling that is applied by the FFT + assuming one uses the default `norm="backward"`, which is also the default + done by the `Exponax` wrapped routines `exponax.fft` and `exponax.ifft`. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + + - `scaling`: The scaling array. + """ + return _build_scaling_array( + num_spatial_dims, num_points, others_fraction=2, indexing=indexing + ) + + def build_reconstructional_scaling_array( num_spatial_dims: int, num_points: int, @@ -491,51 +510,10 @@ def build_reconstructional_scaling_array( - `scaling`: The reconstructional scaling array. """ - right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) - other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) - - right_most_scaling = jnp.where( - right_most_wavenumbers == 0, - num_points, - num_points / 2, - ) - other_scaling = jnp.where( - other_wavenumbers == 0, - num_points, - num_points, # This is the only difference to `build_scaling_array` - ) - - # If N is even, special treatment for the Nyquist mode - if num_points % 2 == 0: - # rfft has the Nyquist mode as positive wavenumber - right_most_scaling = jnp.where( - right_most_wavenumbers == num_points // 2, - num_points, - right_most_scaling, - ) - # standard fft has the Nyquist mode as negative wavenumber - other_scaling = jnp.where( - other_wavenumbers == -num_points // 2, - num_points, - other_scaling, - ) - - scaling_list = [ - other_scaling, - ] * (num_spatial_dims - 1) + [ - right_most_scaling, - ] - - scaling = jnp.prod( - jnp.stack( - jnp.meshgrid(*scaling_list, indexing=indexing), - ), - axis=0, - keepdims=True, + return _build_scaling_array( + num_spatial_dims, num_points, others_fraction=1, indexing=indexing ) - return scaling - def get_modes_slices( num_spatial_dims: int, num_points: int From 84a0023b2ebe4b5ef990562637bd145b5bdeb094 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 09:06:20 +0200 Subject: [PATCH 09/25] Extend and fix documentation --- exponax/_spectral.py | 183 +++++++++++++++++++++++++++++++++---------- 1 file changed, 143 insertions(+), 40 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 5044af2..7ee8d16 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -519,8 +519,60 @@ def get_modes_slices( num_spatial_dims: int, num_points: int ) -> tuple[tuple[slice, ...], ...]: """ - Produces a list of list of slices corresponding to all positive and negative - wavenumber blocks found in the representation of a state in Fourier space. + Produces a tuple of tuple of slices corresponding to all positive and + negative wavenumber blocks found in the representation of a state in Fourier + space. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + + **Returns:** + + - `all_modes_slices`: The tuple of tuple of slices. The outer tuple has + `2^(D-1)` entries if `D` is the number of spatial dimensions. Each inner + tuple has `D+1` entries. + + !!! example + In 1D, there is only one block of coefficients in Fourier space; those + associated with the positive wavenumbers. The additional `slice(None)` + in the beginning is for the channel axis. + + ```python + + slices = exponax.spectral.get_modes_slices(1, 10) + + print(slices) + + # ( + + # (slice(None), slice(None, 6)), + + # ) + + ``` + + In 2D, there are two blocks of coefficients; one for the positive + wavenumbers and one for the negative wavenumbers along the first axis + (which cannot be halved because the `rfft` already acts on the last, the + second spatial axis). + + ```python + + slices = exponax.spectral.get_modes_slices(2, 10) + + print(slices) + + # ( + + # (slice(None), slice(None, 5), slice(None, 6)), + + # (slice(None), slice(-5, None), slice(None, 6)), + + # ) + + ``` """ is_even = num_points % 2 == 0 nyquist_mode = num_points // 2 @@ -555,27 +607,36 @@ def fft( """ Perform a **real-valued** FFT of a field. This function is designed for states in `Exponax` with a leading channel axis and then one, two, or three - following spatial axes, **each of the same length** N. + subsequent spatial axes, **each of the same length** N. Only accepts real-valued input fields and performs a real-valued FFT. Hence, the last axis of the returned field is of length N//2+1. + !!! warning + The argument `num_spatial_dims` can only be correctly inferred if the + array follows the Exponax convention, e.g., no leading batch axis. For a + batched operation, use `jax.vmap` on this function. + **Arguments:** - - `field`: The field to transform, shape `(C, ..., N,)`. - - `num_spatial_dims`: The number of spatial dimensions, i.e., how many - spatial axes follow the channel axis. Can be inferred from the array - if it follows the Exponax convention. For example, it is not allowed - to have a leading batch axis, in such a case use `jax.vmap` on this - function. + + - `field`: The state to transform. + - `num_spatial_dims`: The number of spatial dimensions, i.e., how many + spatial axes follow the channel axis. Can be inferred from the array if + it follows the Exponax convention. For example, it is not allowed to + have a leading batch axis, in such a case use `jax.vmap` on this + function. **Returns:** - - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. + + - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. !!! info Internally uses `jax.numpy.fft.rfftn` with the default settings for the `norm` argument with `norm="backward"`. This means that the forward FFT (this function) does not apply any normalization to the result, only the - [`exponax.ifft`][] function applies normalization. + [`exponax.ifft`][] function applies normalization. To extract the + amplitude of the coefficients divide by + `expoanx.spectral.build_scaling_array`. """ if num_spatial_dims is None: num_spatial_dims = field.ndim - 1 @@ -591,30 +652,38 @@ def ifft( ) -> Float[Array, "C ... N"]: """ Perform the inverse **real-valued** FFT of a field. This is the inverse - operation of `fft`. This function is designed for states in `Exponax` with a - leading channel axis and then one, two, or three following spatial axes. In - state space all spatial axes have the same length N (here called - `num_points`). + operation of `exponax.fft`. This function is designed for states in + `Exponax` with a leading channel axis and then one, two, or three following + spatial axes. In state space all spatial axes have the same length N (here + called `num_points`). Requires a complex-valued field in Fourier space with the last axis of length N//2+1. - The number of points (N, or `num_points`) must be provided if the number of - spatial dimensions is 1. Otherwise, it can be inferred from the shape of the - field. + !!! info + The number of points (N, or `num_points`) must be provided if the number + of spatial dimensions is 1. Otherwise, it can be inferred from the shape + of the field. + + !!! warning + The argument `num_spatial_dims` can only be correctly inferred if the + array follows the Exponax convention, e.g., no leading batch axis. For a + batched operation, use `jax.vmap` on this function. **Arguments:** - - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. - - `num_spatial_dims`: The number of spatial dimensions, i.e., how many - spatial axes follow the channel axis. Can be inferred from the array - if it follows the Exponax convention. For example, it is not allowed - to have a leading batch axis, in such a case use `jax.vmap` on this - function. - - `num_points`: The number of points in each spatial dimension. Can be - inferred if `num_spatial_dims` >= 2 + + - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. + - `num_spatial_dims`: The number of spatial dimensions, i.e., how many + spatial axes follow the channel axis. Can be inferred from the array if + it follows the Exponax convention. For example, it is not allowed to + have a leading batch axis, in such a case use `jax.vmap` on this + function. + - `num_points`: The number of points in each spatial dimension. Can be + inferred if `num_spatial_dims` >= 2 **Returns:** - - `field`: The transformed field, shape `(C, ..., N,)`. + + - `field`: The state in physical space, shape `(C, ..., N,)`. !!! info Internally uses `jax.numpy.fft.irfftn` with the default settings for the @@ -660,24 +729,32 @@ def derivative( the number of degrees of freedom N is even. For this, consider also using the order option. + !!! warning + The argument `num_spatial_dims` can only be correctly inferred if the + array follows the Exponax convention, e.g., no leading batch axis. For a + batched operation, use `jax.vmap` on this function. + **Arguments:** - - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be - `1` for a scalar field or `D` for a vector field. - - `L`: The domain extent. - - `order`: The order of the derivative. Default is `1`. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be + `1` for a scalar field or `D` for a vector field. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `order`: The order of the derivative. Default is `1`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `field_der`: The derivative of the field, shape `(C, D, ..., - (N//2)+1)` or `(D, ..., (N//2)+1)`. + + - `field_der`: The derivative of the field, shape `(C, D, ..., + (N//2)+1)` or `(D, ..., (N//2)+1)`. """ channel_shape = field.shape[0] spatial_shape = field.shape[1:] - D = len(spatial_shape) - N = spatial_shape[0] + num_spatial_dims = len(spatial_shape) + num_points = spatial_shape[0] derivative_operator = build_derivative_operator( - D, domain_extent, N, indexing=indexing + num_spatial_dims, domain_extent, num_points, indexing=indexing ) # # I decided to not use this fix @@ -687,7 +764,7 @@ def derivative( # ) derivative_operator_fixed = derivative_operator**order - field_hat = fft(field, num_spatial_dims=D) + field_hat = fft(field, num_spatial_dims=num_spatial_dims) if channel_shape == 1: # Do not introduce another channel axis field_der_hat = derivative_operator_fixed * field_hat @@ -695,7 +772,9 @@ def derivative( # Create a "derivative axis" right after the channel axis field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...] - field_der = ifft(field_der_hat, num_spatial_dims=D, num_points=N) + field_der = ifft( + field_der_hat, num_spatial_dims=num_spatial_dims, num_points=num_points + ) return field_der @@ -705,6 +784,30 @@ def make_incompressible( *, indexing: str = "ij", ): + """ + Makes a velocity field incompressible by solving the associated pressure + Poisson equation and subtract the pressure gradient. + + With the divergence of the velocity field as the right-hand side, solve the + Poisson equation for pressure `p` + + Δp = - ∇ ⋅ v⃗ + + and then correct the velocity field to be incompressible + + v⃗ ← v⃗ - ∇p + + **Arguments:** + + - `field`: The velocity field to make incompressible, shape `(D, ..., N,)`. + Must have as many channel dimensions as spatial axes. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + + **Returns:** + + - `incompressible_field`: The incompressible velocity field, shape `(D, ..., + N,)`. + """ channel_shape = field.shape[0] spatial_shape = field.shape[1:] num_spatial_dims = len(spatial_shape) From fb609ad5ae5a39002f6d17b3f2af1a4ac4d325f8 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 09:29:42 +0200 Subject: [PATCH 10/25] Add experimental convinience function to extract the Fourier coefficients --- exponax/_spectral.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 7ee8d16..3cf8581 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -850,6 +850,26 @@ def make_incompressible( return incompressible_field +def get_fourier_coefficients( + state: Float[Array, "C ... N"], +) -> Complex[Array, "C ... (N//2)+1"]: + """ + EXPERIMENTAL + + Get the Fourier coefficients of a state in Fourier space. + + **Arguments:** + + - `state`: The state, shape `(C, ..., N,)`. + + **Returns:** + + - `coefficients`: The Fourier coefficients, shape `(C, ..., N//2+1)`. + """ + state_hat = fft(state) + return state_hat / build_scaling_array(state.ndim - 1, state.shape[-1]) + + def get_power_spectrum( field: Float[Array, "C ... N"], ) -> Float[Array, "C (N//2)+1"]: From 55a1e1716e8199d13cf8856ad68f74973b8c4505 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:12:52 +0200 Subject: [PATCH 11/25] Unify the creation of scaling arrays --- exponax/_interpolation.py | 10 +- exponax/_spectral.py | 101 +++++++++++--------- exponax/ic/_gaussian_random_field.py | 4 +- exponax/ic/_truncated_fourier_series.py | 4 +- exponax/nonlin_fun/_vorticity_convection.py | 2 +- 5 files changed, 70 insertions(+), 51 deletions(-) diff --git a/exponax/_interpolation.py b/exponax/_interpolation.py index cdf612f..87d7e06 100644 --- a/exponax/_interpolation.py +++ b/exponax/_interpolation.py @@ -9,7 +9,6 @@ from jaxtyping import Array, Complex, Float from ._spectral import ( - build_reconstructional_scaling_array, build_scaled_wavenumbers, build_scaling_array, fft, @@ -84,8 +83,11 @@ def __init__( self.num_points = state.shape[-1] self.state_hat_scaled = fft(state, num_spatial_dims=self.num_spatial_dims) / ( - build_reconstructional_scaling_array( - self.num_spatial_dims, self.num_points, indexing=indexing + build_scaling_array( + self.num_spatial_dims, + self.num_points, + mode="reconstruction", + indexing=indexing, ) ) self.wavenumbers = build_scaled_wavenumbers( @@ -242,6 +244,7 @@ def map_between_resolutions( ) / build_scaling_array( num_spatial_dims, old_num_points, + mode="norm_compensation", ) if new_num_points > old_num_points: @@ -269,6 +272,7 @@ def map_between_resolutions( new_state_hat = new_state_hat_scaled * build_scaling_array( num_spatial_dims, new_num_points, + mode="norm_compensation", ) if old_num_points > new_num_points: # Downscaling diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 3cf8581..5cec44b 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -397,7 +397,8 @@ def _build_scaling_array( num_spatial_dims: int, num_points: int, *, - others_fraction: Literal[2, 1], + right_most_scaling_denominator: Literal[2, 1], + others_scaling_denominator: Literal[2, 1], indexing: str = "ij", ): """ @@ -413,12 +414,12 @@ def _build_scaling_array( right_most_scaling = jnp.where( right_most_wavenumbers == 0, num_points, - num_points / 2, + num_points / right_most_scaling_denominator, ) other_scaling = jnp.where( other_wavenumbers == 0, num_points, - num_points / others_fraction, # Only difference + num_points / others_scaling_denominator, # Only difference ) # If N is even, special treatment for the Nyquist mode @@ -457,21 +458,37 @@ def build_scaling_array( num_spatial_dims: int, num_points: int, *, + mode: Literal["norm_compensation", "reconstruction", "coef_extraction"], indexing: str = "ij", ) -> Float[Array, "1 ... (N//2)+1"]: """ - Creates an array of the values that would be seen in the result of a - (real-valued) Fourier transform of signal which has amplitude 1 in all - resolvable wavenumbers. - - This can be used to counteract the scaling that is applied by the FFT - assuming one uses the default `norm="backward"`, which is also the default - done by the `Exponax` wrapped routines `exponax.fft` and `exponax.ifft`. + When `exponax.fft` is used, the resulting array in Fourier space represents + a scaled version of the Fourier coefficients. Use this function to produce + arrays to counteract this scaling based on the task. + + 1. `"norm_compensation"`: The scaling is exactly the scaling the + `exponax.ifft` applies. + 2. `"reconstruction"`: Technically `"norm_compensation"` should provide an + array of coefficients that can be used to build a Fourier interpolant + (i.e., what `exponax.FourierInterpolator` does). However, since + `exponax.fft` uses the real-valued FFT, there is only half of the + contribution for the coefficients along the right-most axis. This mode + provides the scaling to counteract this. + 3. `"coef_extraction"`: Any of the former modes (in higher dimensions) does + not produce the same coefficients as the amplitude in the physical space + (because there is a coefficient contribution both in the positive and + negative wavenumber). For example, if the signal `3 * cos(2x)` was + discretized on the domain `[0, 2pi]` with 10 points, the amplitude of + the Fourier coefficient at the 2nd wavenumber would be `3/2` if rescaled + with mode `"norm_compensation"`. This mode provides the scaling to + extract the correct coefficients. **Arguments:** - `num_spatial_dims`: The number of spatial dimensions. - `num_points`: The number of points in each spatial dimension. + - `mode`: The mode of the scaling array. Either `"norm_compensation"`, + `"reconstruction"`, or `"coef_extraction"`. - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. Either `"ij"` or `"xy"`. Default is `"ij"`. @@ -479,40 +496,32 @@ def build_scaling_array( - `scaling`: The scaling array. """ - return _build_scaling_array( - num_spatial_dims, num_points, others_fraction=2, indexing=indexing - ) - - -def build_reconstructional_scaling_array( - num_spatial_dims: int, - num_points: int, - *, - indexing: str = "ij", -) -> Float[Array, "1 ... (N//2)+1"]: - """ - Similar to `build_scaling_array`, but corresponds to the scaling observed - when reconstructing a signal from its Fourier transform. - - This is different because it accounts for the fact `Exponax` uses the `rfft` - which only contributes half the coefficient magnitude for the axis which is - (approximately) halved in the Fourier space. A difference to `build_scaling_array` - can only be observed if `num_spatial_dims >= 2`. - - **Arguments:** - - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. - - **Returns:** - - - `scaling`: The reconstructional scaling array. - """ - return _build_scaling_array( - num_spatial_dims, num_points, others_fraction=1, indexing=indexing - ) + if mode == "norm_compensation": + return _build_scaling_array( + num_spatial_dims, + num_points, + right_most_scaling_denominator=1, + others_scaling_denominator=1, + indexing=indexing, + ) + elif mode == "reconstruction": + return _build_scaling_array( + num_spatial_dims, + num_points, + right_most_scaling_denominator=2, + others_scaling_denominator=1, + indexing=indexing, + ) + elif mode == "coef_extraction": + return _build_scaling_array( + num_spatial_dims, + num_points, + right_most_scaling_denominator=2, + others_scaling_denominator=2, + indexing=indexing, + ) + else: + raise ValueError("Invalid mode.") def get_modes_slices( @@ -867,7 +876,9 @@ def get_fourier_coefficients( - `coefficients`: The Fourier coefficients, shape `(C, ..., N//2+1)`. """ state_hat = fft(state) - return state_hat / build_scaling_array(state.ndim - 1, state.shape[-1]) + return state_hat / build_scaling_array( + state.ndim - 1, state.shape[-1], mode="coef_extraction" + ) def get_power_spectrum( diff --git a/exponax/ic/_gaussian_random_field.py b/exponax/ic/_gaussian_random_field.py index f3f8ba6..1af80c4 100644 --- a/exponax/ic/_gaussian_random_field.py +++ b/exponax/ic/_gaussian_random_field.py @@ -80,7 +80,9 @@ def __call__( noise = noise * amplitude - noise = noise * build_scaling_array(self.num_spatial_dims, num_points) + noise = noise * build_scaling_array( + self.num_spatial_dims, num_points, mode="coef_extraction" + ) ic = ifft(noise, num_spatial_dims=self.num_spatial_dims, num_points=num_points) diff --git a/exponax/ic/_truncated_fourier_series.py b/exponax/ic/_truncated_fourier_series.py index ee4aad0..8dedd86 100644 --- a/exponax/ic/_truncated_fourier_series.py +++ b/exponax/ic/_truncated_fourier_series.py @@ -132,7 +132,9 @@ def __call__( ) fourier_noise = fourier_noise * build_scaling_array( - self.num_spatial_dims, num_points + self.num_spatial_dims, + num_points, + mode="coef_extraction", ) u = ifft( diff --git a/exponax/nonlin_fun/_vorticity_convection.py b/exponax/nonlin_fun/_vorticity_convection.py index 558f56f..efab8dd 100644 --- a/exponax/nonlin_fun/_vorticity_convection.py +++ b/exponax/nonlin_fun/_vorticity_convection.py @@ -171,7 +171,7 @@ def __init__( # `injection_mode`, because we apply the forcing on the vorticity. -injection_mode * injection_scale - * build_scaling_array(num_spatial_dims, num_points), + * build_scaling_array(num_spatial_dims, num_points, mode="coef_extraction"), 0.0, ) From d52a4806088c6960c4843c52a3942fdf44219182 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:13:14 +0200 Subject: [PATCH 12/25] Test scaling array for norm_compensation --- tests/test_spectral_scaling_arrays.py | 39 +++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/test_spectral_scaling_arrays.py diff --git a/tests/test_spectral_scaling_arrays.py b/tests/test_spectral_scaling_arrays.py new file mode 100644 index 0000000..b47e393 --- /dev/null +++ b/tests/test_spectral_scaling_arrays.py @@ -0,0 +1,39 @@ +import jax +import jax.numpy as jnp +import pytest + +import exponax as ex + + +@pytest.mark.parametrize( + "num_spatial_dims,num_points", [(D, N) for D in [1, 2, 3] for N in [10, 11]] +) +def test_building_scaling_array_for_norm_compensation( + num_spatial_dims: int, num_points: int +): + noise = jax.random.normal( + jax.random.PRNGKey(0), (1,) + (num_points,) * num_spatial_dims + ) + + noise_hat_norm_backward = jnp.fft.rfftn( + noise, + axes=ex.spectral.space_indices(num_spatial_dims), + ) + noise_hat_norm_forward = jnp.fft.rfftn( + noise, + axes=ex.spectral.space_indices(num_spatial_dims), + norm="forward", + ) + + scaling_array = ex.spectral.build_scaling_array( + num_spatial_dims, + num_points, + mode="norm_compensation", + ) + + noise_hat_norm_backward_scaled = noise_hat_norm_backward / scaling_array + + assert noise_hat_norm_backward_scaled == pytest.approx(noise_hat_norm_forward) + + +# Mode "reconstruction" is already tested as part of the `test_interpolation.py`` From e96be3fbd90bf55abfd451c559aa68c569f312ef Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:20:03 +0200 Subject: [PATCH 13/25] Add tests for coefficient extraction in 1D --- tests/test_spectral_scaling_arrays.py | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/test_spectral_scaling_arrays.py b/tests/test_spectral_scaling_arrays.py index b47e393..82f4ba6 100644 --- a/tests/test_spectral_scaling_arrays.py +++ b/tests/test_spectral_scaling_arrays.py @@ -37,3 +37,54 @@ def test_building_scaling_array_for_norm_compensation( # Mode "reconstruction" is already tested as part of the `test_interpolation.py`` + + +def test_building_scaling_array_for_coef_extraction(): + # 1D + grid_1d = ex.make_grid(1, 2 * jnp.pi, 10) + + u = 3 * jnp.cos(2 * grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.ones_like(grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(2 * grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2] == pytest.approx(0.0 - 3.0j) + + u = 3.0 * jnp.cos(5 * grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 5] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(5 * grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + # Nyquist mode sine cannot be captured + assert u_hat_scaled.round(5)[0, 5] == pytest.approx(0.0 + 0.0j) From 6fd79f49315b9c5a427cbe322a890bd49735192a Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:26:29 +0200 Subject: [PATCH 14/25] Test for coefficient extration in 2D --- tests/test_spectral_scaling_arrays.py | 125 ++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/tests/test_spectral_scaling_arrays.py b/tests/test_spectral_scaling_arrays.py index 82f4ba6..099b2e5 100644 --- a/tests/test_spectral_scaling_arrays.py +++ b/tests/test_spectral_scaling_arrays.py @@ -88,3 +88,128 @@ def test_building_scaling_array_for_coef_extraction(): ) # Nyquist mode sine cannot be captured assert u_hat_scaled.round(5)[0, 5] == pytest.approx(0.0 + 0.0j) + + # 2D - single terms + grid_2d = ex.make_grid(2, 2 * jnp.pi, 10) + + u = 3.0 * jnp.cos(2 * grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 0] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.cos(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0, 2] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.ones_like(grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0, 0] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(2 * grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 0] == pytest.approx(0.0 - 3.0j) + + u = 3.0 * jnp.sin(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0, 2] == pytest.approx(0.0 - 3.0j) + + u = 3.0 * jnp.cos(5 * grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 5, 0] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.cos(5 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0, 5] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(5 * grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + # Nyquist mode sine cannot be captured + assert u_hat_scaled.round(5)[0, 5, 0] == pytest.approx(0.0 + 0.0j) + + u = 3.0 * jnp.sin(5 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + # Nyquist mode sine cannot be captured + assert u_hat_scaled.round(5)[0, 0, 5] == pytest.approx(0.0 + 0.0j) + + # 2D - mixed terms + u = 3.0 * jnp.cos(2 * grid_2d[0:1]) * jnp.cos(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 2] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(2 * grid_2d[0:1]) * jnp.sin(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 2] == pytest.approx(-3.0 + 0.0j) + + u = 3.0 * jnp.cos(2 * grid_2d[0:1]) * jnp.sin(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 2] == pytest.approx(0.0 - 3.0j) + + u = 3.0 * jnp.sin(2 * grid_2d[0:1]) * jnp.cos(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 2] == pytest.approx(0.0 - 3.0j) + + # TODO: 3D From 30e51638db5b15fe2f1418a70b58441c153d87fe Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:29:48 +0200 Subject: [PATCH 15/25] Enhance docstring --- exponax/_spectral.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 5cec44b..de84955 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -402,11 +402,7 @@ def _build_scaling_array( indexing: str = "ij", ): """ - Shared routine between `build_scaling_array` and - `build_reconstructional_scaling_array`. - - The `others_fraction` argument is used to determine the scaling of the - wavenumbers in the spatial dimensions that are **not** the right-most one. + Low-Level routine to build scaling arrays, prefer using `build_scaling_array`. """ right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) From bae8fe5ea11966780f7526747d4f7714c4f93ee2 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:32:17 +0200 Subject: [PATCH 16/25] Remove experimental function --- exponax/_spectral.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index de84955..a47b3c5 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -855,28 +855,6 @@ def make_incompressible( return incompressible_field -def get_fourier_coefficients( - state: Float[Array, "C ... N"], -) -> Complex[Array, "C ... (N//2)+1"]: - """ - EXPERIMENTAL - - Get the Fourier coefficients of a state in Fourier space. - - **Arguments:** - - - `state`: The state, shape `(C, ..., N,)`. - - **Returns:** - - - `coefficients`: The Fourier coefficients, shape `(C, ..., N//2+1)`. - """ - state_hat = fft(state) - return state_hat / build_scaling_array( - state.ndim - 1, state.shape[-1], mode="coef_extraction" - ) - - def get_power_spectrum( field: Float[Array, "C ... N"], ) -> Float[Array, "C (N//2)+1"]: From 7be9b693a44ad4ae4f3c4e23490a3b01f0d7d858 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:42:08 +0200 Subject: [PATCH 17/25] Fix spectrum function --- exponax/_spectral.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index a47b3c5..b000fc4 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -855,8 +855,10 @@ def make_incompressible( return incompressible_field -def get_power_spectrum( - field: Float[Array, "C ... N"], +def get_spectrum( + state: Float[Array, "C ... N"], + *, + power: bool = True, ) -> Float[Array, "C (N//2)+1"]: """ Preliminary function -> might not be working correctly ... :/ @@ -869,10 +871,15 @@ def get_power_spectrum( https://github.com/scaomath/torch-cfd/blob/8c64319272f7660a57c491d823384130823900fe/sfno/visualizations.py#L114 """ - num_spatial_dims = field.ndim - 1 - num_points = field.shape[-1] + num_spatial_dims = state.ndim - 1 + num_points = state.shape[-1] - field_hat = fft(field, num_spatial_dims=num_spatial_dims) + state_hat = fft(state, num_spatial_dims=num_spatial_dims) + state_hat_scaled = state_hat / build_scaling_array( + num_spatial_dims, + num_points, + mode="reconstruction", # because of rfft + ) wavenumbers_mesh = build_wavenumbers(num_spatial_dims, num_points) wavenumbers_1d = build_wavenumbers(1, num_points) @@ -880,7 +887,10 @@ def get_power_spectrum( dk = wavenumbers_1d[0, 1] - wavenumbers_1d[0, 0] - power = jnp.abs(field_hat) ** 2 * 1 / 2 + if power: + magnitude = 0.5 * jnp.abs(state_hat_scaled) ** 2 + else: + magnitude = jnp.abs(state_hat_scaled) spectrum = [] @@ -893,7 +903,7 @@ def power_in_bucket(p, k): return jnp.mean(p[mask]) for k in wavenumbers_1d[0, :]: - spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(power, k)) + spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k)) spectrum = jnp.stack(spectrum, axis=-1) # spectrum /= jnp.sum(spectrum, axis=-1, keepdims=True) From 92ff949acd3d85fc6efd75bc726e9794fd660c13 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:44:05 +0200 Subject: [PATCH 18/25] Export get_spectrum instead of make_incompressible --- exponax/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exponax/__init__.py b/exponax/__init__.py index cf56e7d..b31fd5a 100644 --- a/exponax/__init__.py +++ b/exponax/__init__.py @@ -6,7 +6,7 @@ from ._forced_stepper import ForcedStepper from ._interpolation import FourierInterpolator, map_between_resolutions from ._repeated_stepper import RepeatedStepper -from ._spectral import derivative, fft, ifft, make_incompressible +from ._spectral import derivative, fft, get_spectrum, ifft from ._utils import ( build_ic_set, make_grid, @@ -26,7 +26,7 @@ "derivative", "fft", "ifft", - "make_incompressible", + "get_spectrum", "make_grid", "rollout", "repeat", From 8e0a8a8eca15645d700cc2fc5b05bc34a141c7d8 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:47:57 +0200 Subject: [PATCH 19/25] Adapt docs --- docs/api/utilities/derivatives.md | 2 +- docs/api/utilities/spectral.md | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/api/utilities/derivatives.md b/docs/api/utilities/derivatives.md index d678942..1b2e480 100644 --- a/docs/api/utilities/derivatives.md +++ b/docs/api/utilities/derivatives.md @@ -4,4 +4,4 @@ --- -::: exponax.make_incompressible \ No newline at end of file +::: exponax.spectral.make_incompressible \ No newline at end of file diff --git a/docs/api/utilities/spectral.md b/docs/api/utilities/spectral.md index e90e9b5..b9d6828 100644 --- a/docs/api/utilities/spectral.md +++ b/docs/api/utilities/spectral.md @@ -4,4 +4,12 @@ --- -::: exponax.ifft \ No newline at end of file +::: exponax.ifft + +--- + +::: exponax.get_spectrum + +--- + +::: exponax.spectral.build_scaling_array \ No newline at end of file From ebd9674cc6b4992902c4e42a07285aa30f8eaf18 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:48:46 +0200 Subject: [PATCH 20/25] Use proper links --- exponax/_spectral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index b000fc4..a680809 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -466,8 +466,8 @@ def build_scaling_array( `exponax.ifft` applies. 2. `"reconstruction"`: Technically `"norm_compensation"` should provide an array of coefficients that can be used to build a Fourier interpolant - (i.e., what `exponax.FourierInterpolator` does). However, since - `exponax.fft` uses the real-valued FFT, there is only half of the + (i.e., what [`exponax.FourierInterpolator`][] does). However, since + [`exponax.fft`][] uses the real-valued FFT, there is only half of the contribution for the coefficients along the right-most axis. This mode provides the scaling to counteract this. 3. `"coef_extraction"`: Any of the former modes (in higher dimensions) does From 6777cb885aff463bc31e22bfeaf0b0088e88c523 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 10:54:57 +0200 Subject: [PATCH 21/25] Enhance docstring --- exponax/_spectral.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index a680809..04bcf56 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -861,15 +861,27 @@ def get_spectrum( power: bool = True, ) -> Float[Array, "C (N//2)+1"]: """ - Preliminary function -> might not be working correctly ... :/ + Compute the Fourier spectrum of a state, either the power spectrum or the + amplitude spectrum. - If passed a vorticity field like produced by - `exponax.stepper.NavierStokesVorticity`, this will produce the enstrophy - spectrum. + !!! info + The returned array will always have two axes, no matter how many spatial + axes the input has. + + !!! info + If it is applied to a vorticity field, it produces the enstrophy spectrum. - Inspired by: - https://github.com/scaomath/torch-cfd/blob/8c64319272f7660a57c491d823384130823900fe/sfno/visualizations.py#L114 + **Arguments:** + + - `state`: The state to compute the spectrum of. The state must follow the `Exponax` + convention with a leading channel axis and then one, two, or three subsequent + spatial axes, **each of the same length** N. + - `power`: Whether to compute the power spectrum or the amplitude spectrum. Default + is `True` meaning the amplitude spectrum. + + **Returns:** + - `spectrum`: The spectrum of the state, shape `(C, (N//2)+1)`. """ num_spatial_dims = state.ndim - 1 num_points = state.shape[-1] From abb0a47433383335fdc43b058801f7d4db6ea986 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 12:29:40 +0200 Subject: [PATCH 22/25] Enhance docstring --- exponax/_spectral.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 04bcf56..852bfc3 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -869,19 +869,28 @@ def get_spectrum( axes the input has. !!! info - If it is applied to a vorticity field, it produces the enstrophy spectrum. + If it is applied to a vorticity field with `power=True` (default), it + produces the enstrophy spectrum. **Arguments:** - - `state`: The state to compute the spectrum of. The state must follow the `Exponax` - convention with a leading channel axis and then one, two, or three subsequent - spatial axes, **each of the same length** N. - - `power`: Whether to compute the power spectrum or the amplitude spectrum. Default - is `True` meaning the amplitude spectrum. + - `state`: The state to compute the spectrum of. The state must follow the + `Exponax` convention with a leading channel axis and then one, two, or + three subsequent spatial axes, **each of the same length** N. + - `power`: Whether to compute the power spectrum or the amplitude spectrum. + Default is `True` meaning the amplitude spectrum. **Returns:** - `spectrum`: The spectrum of the state, shape `(C, (N//2)+1)`. + + !!! note + The binning in higher dimensions can sometimes be counterintuitive. For + example, on a 2D grid if mode `[2, 2]` is populated, this is not + represented in the 2-bin (i.e., when indexing the returning array of + this function at `[2]`), but in the 3-bin because its distance from the + center is `sqrt(2**2 + 2**2) = 2.8284...` which is not in the range of + the 2-bin `[1.5, 2.5)`. """ num_spatial_dims = state.ndim - 1 num_points = state.shape[-1] @@ -912,7 +921,7 @@ def power_in_bucket(p, k): mask = (wavenumbers_norm[0] >= lower_limit) & ( wavenumbers_norm[0] < upper_limit ) - return jnp.mean(p[mask]) + return jnp.sum(p[mask]) for k in wavenumbers_1d[0, :]: spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k)) From da402394e2387398247764652eae8cdbae88bd27 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 12:32:36 +0200 Subject: [PATCH 23/25] Test the spectrum creation --- tests/test_spectrum.py | 66 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 tests/test_spectrum.py diff --git a/tests/test_spectrum.py b/tests/test_spectrum.py new file mode 100644 index 0000000..6bac997 --- /dev/null +++ b/tests/test_spectrum.py @@ -0,0 +1,66 @@ +import jax.numpy as jnp +import pytest + +import exponax as ex + + +def test_amplitude_spectrum(): + # 1D + grid_1d = ex.make_grid(1, 2 * jnp.pi, 128) + + u = 3.0 * jnp.sin(grid_1d) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 1] == pytest.approx(3.0) + + u = 3.0 * jnp.cos(2 * grid_1d) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 2] == pytest.approx(3.0) + + u = 3.0 * jnp.sin(3 * grid_1d) + 4.0 * jnp.cos(3 * grid_1d) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 3] == pytest.approx(jnp.sqrt(3.0**2 + 4.0**2)) + + u = 3.0 * jnp.ones_like(grid_1d) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 0] == pytest.approx(3.0) + + # 2D - single terms + grid_2d = ex.make_grid(2, 2 * jnp.pi, 48) + + u = 3.0 * jnp.sin(grid_2d[0:1]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 1] == pytest.approx(3.0) + + u = 3.0 * jnp.cos(2 * grid_2d[0:1]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 2] == pytest.approx(3.0) + + u = 3.0 * jnp.sin(3 * grid_2d[0:1]) + 4.0 * jnp.cos(3 * grid_2d[0:1]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 3] == pytest.approx(jnp.sqrt(3.0**2 + 4.0**2)) + + u = 3.0 * jnp.ones_like(grid_2d[0:1]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 0] == pytest.approx(3.0) + + u = 3.0 * jnp.sin(grid_2d[1:2]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 1] == pytest.approx(3.0) + + u = 3.0 * jnp.cos(2 * grid_2d[1:2]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 2] == pytest.approx(3.0) + + # 2D - mixed terms + u = 3.0 * jnp.sin(1 * grid_2d[0:1]) * jnp.cos(1 * grid_2d[1:2]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 1] == pytest.approx(3.0) + + u = 3.0 * jnp.sin(2 * grid_2d[0:1]) * jnp.cos(2 * grid_2d[1:2]) + spectrum = ex.get_spectrum(u, power=False) + # The amplitude is in the 3-bin because the wavenumber norm of [2, 2] is + # 2*sqrt(2) = 2.8284 which is not in the interval [1.5, 2.5). + assert spectrum[0, 3] == pytest.approx(3.0) + assert spectrum[0, 2] == pytest.approx(0.0, abs=1e-5) + + # TODO: 3D From bd1eb67d3a8d1e573c31e50ee78fa6671e363dd6 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 12:38:14 +0200 Subject: [PATCH 24/25] Restructure docstring --- exponax/_spectral.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 852bfc3..f519b8a 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -868,10 +868,6 @@ def get_spectrum( The returned array will always have two axes, no matter how many spatial axes the input has. - !!! info - If it is applied to a vorticity field with `power=True` (default), it - produces the enstrophy spectrum. - **Arguments:** - `state`: The state to compute the spectrum of. The state must follow the @@ -884,6 +880,17 @@ def get_spectrum( - `spectrum`: The spectrum of the state, shape `(C, (N//2)+1)`. + !!! tip + The spectrum is usually best presented with a logarithmic y-axis, either + as `plt.semiology` or `plt.loglog`. Sometimes it can be helpful to set + the spectrum below a threshold to zero to better visualize the relevant + parts of the spectrum. This can be done with `jnp.maximum(spectrum, + 1e-10)` for example. + + !!! info + If it is applied to a vorticity field with `power=True` (default), it + produces the enstrophy spectrum. + !!! note The binning in higher dimensions can sometimes be counterintuitive. For example, on a 2D grid if mode `[2, 2]` is populated, this is not From 375623e84acce5ca0f0384ac25be950f30e01d88 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 5 Sep 2024 12:51:36 +0200 Subject: [PATCH 25/25] Tests helper utilities for fft setup --- tests/test_shape_utilties.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 tests/test_shape_utilties.py diff --git a/tests/test_shape_utilties.py b/tests/test_shape_utilties.py new file mode 100644 index 0000000..61b3d6a --- /dev/null +++ b/tests/test_shape_utilties.py @@ -0,0 +1,23 @@ +import exponax as ex + + +def test_space_indices(): + assert ex.spectral.space_indices(1) == (-1,) + assert ex.spectral.space_indices(2) == (-2, -1) + assert ex.spectral.space_indices(3) == (-3, -2, -1) + + +def test_spatial_shape(): + assert ex.spectral.spatial_shape(1, 64) == (64,) + assert ex.spectral.spatial_shape(2, 64) == (64, 64) + assert ex.spectral.spatial_shape(3, 64) == (64, 64, 64) + + +def test_wavenumber_shape(): + assert ex.spectral.wavenumber_shape(1, 64) == (33,) + assert ex.spectral.wavenumber_shape(2, 64) == (64, 33) + assert ex.spectral.wavenumber_shape(3, 64) == (64, 64, 33) + + assert ex.spectral.wavenumber_shape(1, 65) == (33,) + assert ex.spectral.wavenumber_shape(2, 65) == (65, 33) + assert ex.spectral.wavenumber_shape(3, 65) == (65, 65, 33)