From 379e794c2e70365393f5d5d69adb51fe9bc3b963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20K=C3=B6hler?= <27728103+Ceyron@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:01:48 +0200 Subject: [PATCH] Spectral updates (#35) * Preliminary function to compute the spectrum * Potential improvement to spectrum computation * Add hint on enstrophy * Fix on wavenumber norm computation * Rename to better reflect its effect on oddball mode * Add tests on filter mask * Improve first half of spectral documentation * Merge common subexpressions * Extend and fix documentation * Add experimental convinience function to extract the Fourier coefficients * Unify the creation of scaling arrays * Test scaling array for norm_compensation * Add tests for coefficient extraction in 1D * Test for coefficient extration in 2D * Enhance docstring * Remove experimental function * Fix spectrum function * Export get_spectrum instead of make_incompressible * Adapt docs * Use proper links * Enhance docstring * Enhance docstring * Test the spectrum creation * Restructure docstring * Tests helper utilities for fft setup --- docs/api/utilities/derivatives.md | 2 +- docs/api/utilities/spectral.md | 10 +- exponax/__init__.py | 4 +- exponax/_interpolation.py | 16 +- exponax/_spectral.py | 661 ++++++++++++++------ exponax/ic/_gaussian_random_field.py | 4 +- exponax/ic/_truncated_fourier_series.py | 4 +- exponax/nonlin_fun/_vorticity_convection.py | 2 +- tests/test_filter_masks.py | 451 +++++++++++++ tests/test_shape_utilties.py | 23 + tests/test_spectral_scaling_arrays.py | 215 +++++++ tests/test_spectrum.py | 66 ++ 12 files changed, 1257 insertions(+), 201 deletions(-) create mode 100644 tests/test_filter_masks.py create mode 100644 tests/test_shape_utilties.py create mode 100644 tests/test_spectral_scaling_arrays.py create mode 100644 tests/test_spectrum.py diff --git a/docs/api/utilities/derivatives.md b/docs/api/utilities/derivatives.md index d678942..1b2e480 100644 --- a/docs/api/utilities/derivatives.md +++ b/docs/api/utilities/derivatives.md @@ -4,4 +4,4 @@ --- -::: exponax.make_incompressible \ No newline at end of file +::: exponax.spectral.make_incompressible \ No newline at end of file diff --git a/docs/api/utilities/spectral.md b/docs/api/utilities/spectral.md index e90e9b5..b9d6828 100644 --- a/docs/api/utilities/spectral.md +++ b/docs/api/utilities/spectral.md @@ -4,4 +4,12 @@ --- -::: exponax.ifft \ No newline at end of file +::: exponax.ifft + +--- + +::: exponax.get_spectrum + +--- + +::: exponax.spectral.build_scaling_array \ No newline at end of file diff --git a/exponax/__init__.py b/exponax/__init__.py index cf56e7d..b31fd5a 100644 --- a/exponax/__init__.py +++ b/exponax/__init__.py @@ -6,7 +6,7 @@ from ._forced_stepper import ForcedStepper from ._interpolation import FourierInterpolator, map_between_resolutions from ._repeated_stepper import RepeatedStepper -from ._spectral import derivative, fft, ifft, make_incompressible +from ._spectral import derivative, fft, get_spectrum, ifft from ._utils import ( build_ic_set, make_grid, @@ -26,7 +26,7 @@ "derivative", "fft", "ifft", - "make_incompressible", + "get_spectrum", "make_grid", "rollout", "repeat", diff --git a/exponax/_interpolation.py b/exponax/_interpolation.py index 412a46a..87d7e06 100644 --- a/exponax/_interpolation.py +++ b/exponax/_interpolation.py @@ -9,13 +9,12 @@ from jaxtyping import Array, Complex, Float from ._spectral import ( - build_reconstructional_scaling_array, build_scaled_wavenumbers, build_scaling_array, fft, get_modes_slices, ifft, - nyquist_filter_mask, + oddball_filter_mask, space_indices, wavenumber_shape, ) @@ -84,8 +83,11 @@ def __init__( self.num_points = state.shape[-1] self.state_hat_scaled = fft(state, num_spatial_dims=self.num_spatial_dims) / ( - build_reconstructional_scaling_array( - self.num_spatial_dims, self.num_points, indexing=indexing + build_scaling_array( + self.num_spatial_dims, + self.num_points, + mode="reconstruction", + indexing=indexing, ) ) self.wavenumbers = build_scaled_wavenumbers( @@ -242,12 +244,13 @@ def map_between_resolutions( ) / build_scaling_array( num_spatial_dims, old_num_points, + mode="norm_compensation", ) if new_num_points > old_num_points: # Upscaling if old_num_points % 2 == 0 and oddball_zero: - old_state_hat_scaled *= nyquist_filter_mask( + old_state_hat_scaled *= oddball_filter_mask( num_spatial_dims, old_num_points ) @@ -269,11 +272,12 @@ def map_between_resolutions( new_state_hat = new_state_hat_scaled * build_scaling_array( num_spatial_dims, new_num_points, + mode="norm_compensation", ) if old_num_points > new_num_points: # Downscaling if new_num_points % 2 == 0 and oddball_zero: - new_state_hat *= nyquist_filter_mask(num_spatial_dims, new_num_points) + new_state_hat *= oddball_filter_mask(num_spatial_dims, new_num_points) new_state = ifft( new_state_hat, diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 042362e..f519b8a 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -1,10 +1,13 @@ from itertools import product -from typing import Optional, TypeVar, Union +from typing import Literal, Optional, TypeVar, Union +import jax import jax.numpy as jnp from jaxtyping import Array, Bool, Complex, Float +C = TypeVar("C") D = TypeVar("D") +N = TypeVar("N") def build_wavenumbers( @@ -19,14 +22,16 @@ def build_wavenumbers( `jax.numpy.fft.rfftn`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `wavenumbers`: An array of wavenumber integer coordinates, shape - `(D, ..., (N//2)+1)`. + + - `wavenumbers`: An array of wavenumber integer coordinates, shape + `(D, ..., (N//2)+1)`. """ right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) @@ -52,20 +57,26 @@ def build_scaled_wavenumbers( indexing: str = "ij", ) -> Float[Array, "D ... (N//2)+1"]: """ - Setup an array containing scaled wavenumbers associated with a - "num_spatial_dims"-dimensional rfft (real-valued FFT) - `jax.numpy.fft.rfftn`. Scaling is done by `2 * pi / L`. + Setup an array containing **scaled** wavenumbers associated with a + "num_spatial_dims"-dimensional rfft (real-valued FFT) `jax.numpy.fft.rfftn`. + Scaling is done by `2 * pi / L`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `domain_extent`: The domain extent. - - `num_points`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The domain extent. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `wavenumbers`: An array of wavenumber integer coordinates, shape - `(D, ..., (N//2)+1)`. + + - `wavenumbers`: An array of wavenumber integer coordinates, shape + `(D, ..., (N//2)+1)`. + + !!! info + These correctly scaled wavenumbers are used to set up derivative + operators via `1j * wavenumbers`. """ scale = 2 * jnp.pi / domain_extent wavenumbers = build_wavenumbers(num_spatial_dims, num_points, indexing=indexing) @@ -83,15 +94,21 @@ def build_derivative_operator( Setup the derivative operator in Fourier space. **Arguments:** - - `D`: The number of spatial dimensions. - - `L`: The domain extent. - - `N`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** the + right boundary point. In higher dimensions; the number of points in each + dimension is the same. Hence, the total number of degrees of freedom is + `Nᵈ`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. **Returns:** - - `derivative_operator`: The derivative operator, shape `(D, ..., - N//2+1)`. + + - `derivative_operator`: The derivative operator in Fourier space + (complex-valued array) """ return 1j * build_scaled_wavenumbers( num_spatial_dims, domain_extent, num_points, indexing=indexing @@ -104,16 +121,27 @@ def build_laplace_operator( order: int = 2, ) -> Complex[Array, "1 ... (N//2)+1"]: """ - Given the derivative operator of [`build_derivative_operator`], return the - Laplace operator. + Given the derivative operator of + [`exponax.spectral.build_derivative_operator`], return the Laplace operator. + + In state space: + + Δ = ∇ ⋅ ∇ + + And in Fourier space: + + i² k⃗ᵀ k⃗ = - k⃗ᵀ k⃗ **Arguments:** - - `derivative_operator`: The derivative operator, shape `(D, ..., - N//2+1)`. - - `order`: The order of the Laplace operator. Default is `2`. + + - `derivative_operator`: The derivative operator in Fourier space. + - `order`: The order of the Laplace operator. Default is `2`. Use a higher + even number for "higher-order Laplacians". For example, `order=4` will + return the biharmonic operator (without spatial mixing). **Returns:** - - `laplace_operator`: The Laplace operator, shape `(1, ..., N//2+1)`. + + - `laplace_operator`: The Laplace operator in Fourier space. """ if order % 2 != 0: raise ValueError("Order must be even.") @@ -128,18 +156,31 @@ def build_gradient_inner_product_operator( order: int = 1, ) -> Complex[Array, "1 ... (N//2)+1"]: """ - Given the derivative operator of [`build_derivative_operator`] and a velocity - field, return the operator that computes the inner product of the gradient - with the velocity. + Given the derivative operator of [`build_derivative_operator`] and a + velocity vector, return the operator that computes the inner product of the + gradient with the velocity. + + In state space this is: + + c⃗ ⋅ ∇ + + And in Fourier space: + + c⃗ ⋅ i k⃗ **Arguments:** - - `derivative_operator`: The derivative operator, shape `(D, ..., - N//2+1)`. - - `velocity`: The velocity field, shape `(D,)`. - - `order`: The order of the gradient. Default is `1`. + + - `derivative_operator`: The derivative operator in Fourier space. + - `velocity`: The velocity vector, must be an array with one axis with as + many dimensions as the derivative operator has in its leading axis. + - `order`: The order of the gradient. Default is `1` which is the "regular + gradient". Use higher orders for higher-order gradients given in terms + elementwise products. For example, `order=3` will return `c⃗ ⋅ (∇ ⊙ ∇ ⊙ + ∇)` **Returns:** - - `operator`: The operator, shape `(1, ..., N//2+1)`. + + - `operator`: The operator in Fourier space. """ if order % 2 != 1: raise ValueError("Order must be odd.") @@ -158,52 +199,47 @@ def build_gradient_inner_product_operator( # Need to add singleton channel axis operator = operator[None, ...] - # Old form below - # # Need to move the channel/dimension axis last to enable autobroadcast over - # # the arbitrary number of spatial axes, Then we can move this singleton axis - # # back to the front - # operator = jnp.swapaxes( - # jnp.sum( - # velocity - # * jnp.swapaxes( - # derivative_operator**order, - # 0, - # -1, - # ), - # axis=-1, - # keepdims=True, - # ), - # 0, - # -1, - # ) - return operator def space_indices(num_spatial_dims: int) -> tuple[int, ...]: """ - Returns the indices within a field array that correspond to the spatial - dimensions. + Returns the axes indices within a state array that correspond to the spatial + axes. + + !!! example + For a 2D field array, the spatial indices are `(-2, -1)`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. + + - `num_spatial_dims`: The number of spatial dimensions. **Returns:** - - `indices`: The indices of the spatial dimensions. + + - `indices`: The indices of the spatial axes. """ return tuple(range(-num_spatial_dims, 0)) def spatial_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: """ - Returns the shape of a spatial field array. + Returns the shape of a spatial field array (without its leading channel + axis). This follows the `Exponax` convention that the resolution is + indentical in each dimension. + + !!! example + For a 2D field array with 64 points in each dimension, the spatial shape + is `(64, 64)`. For a 3D field array with 32 points in each dimension, + the spatial shape is `(32, 32, 32)`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. **Returns:** - - `shape`: The shape of the spatial field array. + + - `shape`: The shape of the spatial field array. """ return (num_points,) * num_spatial_dims @@ -211,14 +247,23 @@ def spatial_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: def wavenumber_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: """ Returns the spatial shape of a field in Fourier space (assuming the usage of - rfft, `jax.numpy.fft.rfftn`). + `exponax.fft` which internall performs a real-valued fft + `jax.numpy.fft.rfftn`). + + !!! example + For a 2D field array with 64 points in each dimension, the wavenumber shape + is `(64, 33)`. For a 3D field array with 32 points in each dimension, + the spatial shape is `(32, 32, 17)`. For a 1D field array with 51 points, + the wavenumber shape is `(26,)`. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. **Returns:** - - `shape`: The shape of the spatial field array. + + - `shape`: The shape of the spatial axes of a state array in Fourier space. """ return (num_points,) * (num_spatial_dims - 1) + (num_points // 2 + 1,) @@ -235,16 +280,46 @@ def low_pass_filter_mask( Create a low-pass filter mask in Fourier space. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. - - `cutoff`: The cutoff wavenumber. This is inclusive. - - `axis_separate`: Whether to apply the cutoff to each axis separately. - Default is `True`. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `cutoff`: The cutoff wavenumber. This is inclusive. + - `axis_separate`: Whether to apply the cutoff to each axis separately. + If `True` (default) the low-pass chunk is a hypercube in Fourier space. + If `False`, the low-pass chunk is a sphere in Fourier space. Only + relevant for `num_spatial_dims` >= 2. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `mask`: The low-pass filter mask, shape `(1, ..., N//2+1)`. + + - `mask`: The low-pass filter mask. + + + !!! example + In 1D with 10 points, a cutoff of 3 will produce the mask + + ```python + + array([[ True, True, True, True, False, False]]) + + ``` + + To better understand this, let's produce the corresponding wavenumbers: + + ```python + + wn = exponax.spectral.build_wavenumbers(1, 10) + + print(wn) + + # array([[0, 1, 2, 3, 4, 5]]) + + ``` + + There are 6 wavenumbers in total (because this equals `(N//2)+1`), the + zeroth wavenumber is the mean mode, and then the mask includes the next + three wavenumbers because its **`cutoff` is inclusive**. """ wavenumbers = build_wavenumbers(num_spatial_dims, num_points, indexing=indexing) @@ -253,27 +328,50 @@ def low_pass_filter_mask( for wn_grid in wavenumbers: mask = mask & (jnp.abs(wn_grid) <= cutoff) else: - mask = jnp.linalg.norm(mask, axis=0) <= cutoff + mask = jnp.linalg.norm(wavenumbers, axis=0) <= cutoff mask = mask[jnp.newaxis, ...] return mask -def nyquist_filter_mask( +def oddball_filter_mask( num_spatial_dims: int, num_points: int, ) -> Bool[Array, "1 ... N"]: """ - Creates mask that if multiplied with a field in Fourier space will remove - the Nyquist mode. + Creates mask that if multiplied with a field in Fourier space remove the + Nyquist mode if the number of degrees of freedom is even. **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. **Returns:** - - `mask`: The Nyquist filter mask, shape `(1, ..., N//2+1)`. + + - `mask`: The oddball filter mask which is `True` for all wavenumbers except + the Nyquist mode if the number of degrees of freedom is even. + + !!! example + ```python + + mask_even = exponax.spectral.oddball_filter_mask(1, 6) + + # array([[ True, True, True, False]]) + + mask_odd = exponax.spectral.oddball_filter_mask(1, 7) + + # array([[ True, True, True, True]]) + + ``` + + For higher-dimensional examples, see `tests/test_filter_masks.py`. + + !!! info + For more background on why this is needed, see + https://www.mech.kth.se/~mattias/simson-user-guide-v4.0.pdf section + 6.2.4 and https://math.mit.edu/~stevenj/fft-deriv.pdf """ if num_points % 2 == 1: # Odd number of degrees of freedom (no issue with the Nyquist mode) @@ -289,36 +387,22 @@ def nyquist_filter_mask( return low_pass_filter_mask( num_spatial_dims, num_points, + # The cutoff is **inclusive** cutoff=mode_below_nyquist - 1, axis_separate=True, ) - # # Todo: Do we need the below? - # wavenumbers = build_wavenumbers(D, N, scaled=False) - # mask = True - # for wn_grid in wavenumbers: - # mask = mask & (wn_grid != -mode_below_nyquist) - # return mask - -def build_scaling_array( +def _build_scaling_array( num_spatial_dims: int, num_points: int, *, + right_most_scaling_denominator: Literal[2, 1], + others_scaling_denominator: Literal[2, 1], indexing: str = "ij", -) -> Float[Array, "1 ... (N//2)+1"]: +): """ - Creates an array of the values that would be seen in the result of a - (real-valued) Fourier transform of a signal of amplitude 1. - - **Arguments:** - - `num_spatial_dims`: The number of spatial dimensions. - - `num_points`: The number of points in each spatial dimension. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. - - **Returns:** - - `scaling`: The scaling array, shape `(1, ..., N//2+1)`. + Low-Level routine to build scaling arrays, prefer using `build_scaling_array`. """ right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) @@ -326,12 +410,12 @@ def build_scaling_array( right_most_scaling = jnp.where( right_most_wavenumbers == 0, num_points, - num_points / 2, + num_points / right_most_scaling_denominator, ) other_scaling = jnp.where( other_wavenumbers == 0, num_points, - num_points / 2, + num_points / others_scaling_denominator, # Only difference ) # If N is even, special treatment for the Nyquist mode @@ -366,68 +450,134 @@ def build_scaling_array( return scaling -def build_reconstructional_scaling_array( +def build_scaling_array( num_spatial_dims: int, num_points: int, *, + mode: Literal["norm_compensation", "reconstruction", "coef_extraction"], indexing: str = "ij", ) -> Float[Array, "1 ... (N//2)+1"]: """ - Similar to `build_scaling_array`, but corresponds to the scaling observed - when reconstructing a signal from its Fourier transform. - """ - right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) - other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) + When `exponax.fft` is used, the resulting array in Fourier space represents + a scaled version of the Fourier coefficients. Use this function to produce + arrays to counteract this scaling based on the task. + + 1. `"norm_compensation"`: The scaling is exactly the scaling the + `exponax.ifft` applies. + 2. `"reconstruction"`: Technically `"norm_compensation"` should provide an + array of coefficients that can be used to build a Fourier interpolant + (i.e., what [`exponax.FourierInterpolator`][] does). However, since + [`exponax.fft`][] uses the real-valued FFT, there is only half of the + contribution for the coefficients along the right-most axis. This mode + provides the scaling to counteract this. + 3. `"coef_extraction"`: Any of the former modes (in higher dimensions) does + not produce the same coefficients as the amplitude in the physical space + (because there is a coefficient contribution both in the positive and + negative wavenumber). For example, if the signal `3 * cos(2x)` was + discretized on the domain `[0, 2pi]` with 10 points, the amplitude of + the Fourier coefficient at the 2nd wavenumber would be `3/2` if rescaled + with mode `"norm_compensation"`. This mode provides the scaling to + extract the correct coefficients. - right_most_scaling = jnp.where( - right_most_wavenumbers == 0, - num_points, - num_points / 2, - ) - other_scaling = jnp.where( - other_wavenumbers == 0, - num_points, - num_points, # This is the only difference to `build_scaling_array` - ) + **Arguments:** - # If N is even, special treatment for the Nyquist mode - if num_points % 2 == 0: - # rfft has the Nyquist mode as positive wavenumber - right_most_scaling = jnp.where( - right_most_wavenumbers == num_points // 2, + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `mode`: The mode of the scaling array. Either `"norm_compensation"`, + `"reconstruction"`, or `"coef_extraction"`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + + - `scaling`: The scaling array. + """ + if mode == "norm_compensation": + return _build_scaling_array( + num_spatial_dims, num_points, - right_most_scaling, + right_most_scaling_denominator=1, + others_scaling_denominator=1, + indexing=indexing, ) - # standard fft has the Nyquist mode as negative wavenumber - other_scaling = jnp.where( - other_wavenumbers == -num_points // 2, + elif mode == "reconstruction": + return _build_scaling_array( + num_spatial_dims, num_points, - other_scaling, + right_most_scaling_denominator=2, + others_scaling_denominator=1, + indexing=indexing, ) - - scaling_list = [ - other_scaling, - ] * (num_spatial_dims - 1) + [ - right_most_scaling, - ] - - scaling = jnp.prod( - jnp.stack( - jnp.meshgrid(*scaling_list, indexing=indexing), - ), - axis=0, - keepdims=True, - ) - - return scaling + elif mode == "coef_extraction": + return _build_scaling_array( + num_spatial_dims, + num_points, + right_most_scaling_denominator=2, + others_scaling_denominator=2, + indexing=indexing, + ) + else: + raise ValueError("Invalid mode.") def get_modes_slices( num_spatial_dims: int, num_points: int ) -> tuple[tuple[slice, ...], ...]: """ - Produces a list of list of slices corresponding to all positive and negative - wavenumber blocks found in the representation of a state in Fourier space. + Produces a tuple of tuple of slices corresponding to all positive and + negative wavenumber blocks found in the representation of a state in Fourier + space. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + + **Returns:** + + - `all_modes_slices`: The tuple of tuple of slices. The outer tuple has + `2^(D-1)` entries if `D` is the number of spatial dimensions. Each inner + tuple has `D+1` entries. + + !!! example + In 1D, there is only one block of coefficients in Fourier space; those + associated with the positive wavenumbers. The additional `slice(None)` + in the beginning is for the channel axis. + + ```python + + slices = exponax.spectral.get_modes_slices(1, 10) + + print(slices) + + # ( + + # (slice(None), slice(None, 6)), + + # ) + + ``` + + In 2D, there are two blocks of coefficients; one for the positive + wavenumbers and one for the negative wavenumbers along the first axis + (which cannot be halved because the `rfft` already acts on the last, the + second spatial axis). + + ```python + + slices = exponax.spectral.get_modes_slices(2, 10) + + print(slices) + + # ( + + # (slice(None), slice(None, 5), slice(None, 6)), + + # (slice(None), slice(-5, None), slice(None, 6)), + + # ) + + ``` """ is_even = num_points % 2 == 0 nyquist_mode = num_points // 2 @@ -462,27 +612,36 @@ def fft( """ Perform a **real-valued** FFT of a field. This function is designed for states in `Exponax` with a leading channel axis and then one, two, or three - following spatial axes, **each of the same length** N. + subsequent spatial axes, **each of the same length** N. Only accepts real-valued input fields and performs a real-valued FFT. Hence, the last axis of the returned field is of length N//2+1. + !!! warning + The argument `num_spatial_dims` can only be correctly inferred if the + array follows the Exponax convention, e.g., no leading batch axis. For a + batched operation, use `jax.vmap` on this function. + **Arguments:** - - `field`: The field to transform, shape `(C, ..., N,)`. - - `num_spatial_dims`: The number of spatial dimensions, i.e., how many - spatial axes follow the channel axis. Can be inferred from the array - if it follows the Exponax convention. For example, it is not allowed - to have a leading batch axis, in such a case use `jax.vmap` on this - function. + + - `field`: The state to transform. + - `num_spatial_dims`: The number of spatial dimensions, i.e., how many + spatial axes follow the channel axis. Can be inferred from the array if + it follows the Exponax convention. For example, it is not allowed to + have a leading batch axis, in such a case use `jax.vmap` on this + function. **Returns:** - - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. + + - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. !!! info Internally uses `jax.numpy.fft.rfftn` with the default settings for the `norm` argument with `norm="backward"`. This means that the forward FFT (this function) does not apply any normalization to the result, only the - [`exponax.ifft`][] function applies normalization. + [`exponax.ifft`][] function applies normalization. To extract the + amplitude of the coefficients divide by + `expoanx.spectral.build_scaling_array`. """ if num_spatial_dims is None: num_spatial_dims = field.ndim - 1 @@ -498,30 +657,38 @@ def ifft( ) -> Float[Array, "C ... N"]: """ Perform the inverse **real-valued** FFT of a field. This is the inverse - operation of `fft`. This function is designed for states in `Exponax` with a - leading channel axis and then one, two, or three following spatial axes. In - state space all spatial axes have the same length N (here called - `num_points`). + operation of `exponax.fft`. This function is designed for states in + `Exponax` with a leading channel axis and then one, two, or three following + spatial axes. In state space all spatial axes have the same length N (here + called `num_points`). Requires a complex-valued field in Fourier space with the last axis of length N//2+1. - The number of points (N, or `num_points`) must be provided if the number of - spatial dimensions is 1. Otherwise, it can be inferred from the shape of the - field. + !!! info + The number of points (N, or `num_points`) must be provided if the number + of spatial dimensions is 1. Otherwise, it can be inferred from the shape + of the field. + + !!! warning + The argument `num_spatial_dims` can only be correctly inferred if the + array follows the Exponax convention, e.g., no leading batch axis. For a + batched operation, use `jax.vmap` on this function. **Arguments:** - - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. - - `num_spatial_dims`: The number of spatial dimensions, i.e., how many - spatial axes follow the channel axis. Can be inferred from the array - if it follows the Exponax convention. For example, it is not allowed - to have a leading batch axis, in such a case use `jax.vmap` on this - function. - - `num_points`: The number of points in each spatial dimension. Can be - inferred if `num_spatial_dims` >= 2 + + - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`. + - `num_spatial_dims`: The number of spatial dimensions, i.e., how many + spatial axes follow the channel axis. Can be inferred from the array if + it follows the Exponax convention. For example, it is not allowed to + have a leading batch axis, in such a case use `jax.vmap` on this + function. + - `num_points`: The number of points in each spatial dimension. Can be + inferred if `num_spatial_dims` >= 2 **Returns:** - - `field`: The transformed field, shape `(C, ..., N,)`. + + - `field`: The state in physical space, shape `(C, ..., N,)`. !!! info Internally uses `jax.numpy.fft.irfftn` with the default settings for the @@ -567,24 +734,32 @@ def derivative( the number of degrees of freedom N is even. For this, consider also using the order option. + !!! warning + The argument `num_spatial_dims` can only be correctly inferred if the + array follows the Exponax convention, e.g., no leading batch axis. For a + batched operation, use `jax.vmap` on this function. + **Arguments:** - - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be - `1` for a scalar field or `D` for a vector field. - - `L`: The domain extent. - - `order`: The order of the derivative. Default is `1`. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. + + - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be + `1` for a scalar field or `D` for a vector field. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `order`: The order of the derivative. Default is `1`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. **Returns:** - - `field_der`: The derivative of the field, shape `(C, D, ..., - (N//2)+1)` or `(D, ..., (N//2)+1)`. + + - `field_der`: The derivative of the field, shape `(C, D, ..., + (N//2)+1)` or `(D, ..., (N//2)+1)`. """ channel_shape = field.shape[0] spatial_shape = field.shape[1:] - D = len(spatial_shape) - N = spatial_shape[0] + num_spatial_dims = len(spatial_shape) + num_points = spatial_shape[0] derivative_operator = build_derivative_operator( - D, domain_extent, N, indexing=indexing + num_spatial_dims, domain_extent, num_points, indexing=indexing ) # # I decided to not use this fix @@ -594,7 +769,7 @@ def derivative( # ) derivative_operator_fixed = derivative_operator**order - field_hat = fft(field, num_spatial_dims=D) + field_hat = fft(field, num_spatial_dims=num_spatial_dims) if channel_shape == 1: # Do not introduce another channel axis field_der_hat = derivative_operator_fixed * field_hat @@ -602,7 +777,9 @@ def derivative( # Create a "derivative axis" right after the channel axis field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...] - field_der = ifft(field_der_hat, num_spatial_dims=D, num_points=N) + field_der = ifft( + field_der_hat, num_spatial_dims=num_spatial_dims, num_points=num_points + ) return field_der @@ -612,6 +789,30 @@ def make_incompressible( *, indexing: str = "ij", ): + """ + Makes a velocity field incompressible by solving the associated pressure + Poisson equation and subtract the pressure gradient. + + With the divergence of the velocity field as the right-hand side, solve the + Poisson equation for pressure `p` + + Δp = - ∇ ⋅ v⃗ + + and then correct the velocity field to be incompressible + + v⃗ ← v⃗ - ∇p + + **Arguments:** + + - `field`: The velocity field to make incompressible, shape `(D, ..., N,)`. + Must have as many channel dimensions as spatial axes. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + + **Returns:** + + - `incompressible_field`: The incompressible velocity field, shape `(D, ..., + N,)`. + """ channel_shape = field.shape[0] spatial_shape = field.shape[1:] num_spatial_dims = len(spatial_shape) @@ -652,3 +853,87 @@ def make_incompressible( ) return incompressible_field + + +def get_spectrum( + state: Float[Array, "C ... N"], + *, + power: bool = True, +) -> Float[Array, "C (N//2)+1"]: + """ + Compute the Fourier spectrum of a state, either the power spectrum or the + amplitude spectrum. + + !!! info + The returned array will always have two axes, no matter how many spatial + axes the input has. + + **Arguments:** + + - `state`: The state to compute the spectrum of. The state must follow the + `Exponax` convention with a leading channel axis and then one, two, or + three subsequent spatial axes, **each of the same length** N. + - `power`: Whether to compute the power spectrum or the amplitude spectrum. + Default is `True` meaning the amplitude spectrum. + + **Returns:** + + - `spectrum`: The spectrum of the state, shape `(C, (N//2)+1)`. + + !!! tip + The spectrum is usually best presented with a logarithmic y-axis, either + as `plt.semiology` or `plt.loglog`. Sometimes it can be helpful to set + the spectrum below a threshold to zero to better visualize the relevant + parts of the spectrum. This can be done with `jnp.maximum(spectrum, + 1e-10)` for example. + + !!! info + If it is applied to a vorticity field with `power=True` (default), it + produces the enstrophy spectrum. + + !!! note + The binning in higher dimensions can sometimes be counterintuitive. For + example, on a 2D grid if mode `[2, 2]` is populated, this is not + represented in the 2-bin (i.e., when indexing the returning array of + this function at `[2]`), but in the 3-bin because its distance from the + center is `sqrt(2**2 + 2**2) = 2.8284...` which is not in the range of + the 2-bin `[1.5, 2.5)`. + """ + num_spatial_dims = state.ndim - 1 + num_points = state.shape[-1] + + state_hat = fft(state, num_spatial_dims=num_spatial_dims) + state_hat_scaled = state_hat / build_scaling_array( + num_spatial_dims, + num_points, + mode="reconstruction", # because of rfft + ) + + wavenumbers_mesh = build_wavenumbers(num_spatial_dims, num_points) + wavenumbers_1d = build_wavenumbers(1, num_points) + wavenumbers_norm = jnp.linalg.norm(wavenumbers_mesh, axis=0, keepdims=True) + + dk = wavenumbers_1d[0, 1] - wavenumbers_1d[0, 0] + + if power: + magnitude = 0.5 * jnp.abs(state_hat_scaled) ** 2 + else: + magnitude = jnp.abs(state_hat_scaled) + + spectrum = [] + + def power_in_bucket(p, k): + lower_limit = k - dk / 2 + upper_limit = k + dk / 2 + mask = (wavenumbers_norm[0] >= lower_limit) & ( + wavenumbers_norm[0] < upper_limit + ) + return jnp.sum(p[mask]) + + for k in wavenumbers_1d[0, :]: + spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k)) + + spectrum = jnp.stack(spectrum, axis=-1) + # spectrum /= jnp.sum(spectrum, axis=-1, keepdims=True) + + return spectrum diff --git a/exponax/ic/_gaussian_random_field.py b/exponax/ic/_gaussian_random_field.py index f3f8ba6..1af80c4 100644 --- a/exponax/ic/_gaussian_random_field.py +++ b/exponax/ic/_gaussian_random_field.py @@ -80,7 +80,9 @@ def __call__( noise = noise * amplitude - noise = noise * build_scaling_array(self.num_spatial_dims, num_points) + noise = noise * build_scaling_array( + self.num_spatial_dims, num_points, mode="coef_extraction" + ) ic = ifft(noise, num_spatial_dims=self.num_spatial_dims, num_points=num_points) diff --git a/exponax/ic/_truncated_fourier_series.py b/exponax/ic/_truncated_fourier_series.py index ee4aad0..8dedd86 100644 --- a/exponax/ic/_truncated_fourier_series.py +++ b/exponax/ic/_truncated_fourier_series.py @@ -132,7 +132,9 @@ def __call__( ) fourier_noise = fourier_noise * build_scaling_array( - self.num_spatial_dims, num_points + self.num_spatial_dims, + num_points, + mode="coef_extraction", ) u = ifft( diff --git a/exponax/nonlin_fun/_vorticity_convection.py b/exponax/nonlin_fun/_vorticity_convection.py index 558f56f..efab8dd 100644 --- a/exponax/nonlin_fun/_vorticity_convection.py +++ b/exponax/nonlin_fun/_vorticity_convection.py @@ -171,7 +171,7 @@ def __init__( # `injection_mode`, because we apply the forcing on the vorticity. -injection_mode * injection_scale - * build_scaling_array(num_spatial_dims, num_points), + * build_scaling_array(num_spatial_dims, num_points, mode="coef_extraction"), 0.0, ) diff --git a/tests/test_filter_masks.py b/tests/test_filter_masks.py new file mode 100644 index 0000000..2818566 --- /dev/null +++ b/tests/test_filter_masks.py @@ -0,0 +1,451 @@ +import numpy as np + +import exponax as ex + + +def test_low_pass_filter_masks_1d(): + # Need to test both for even and odd number of points because that changes + # how the Nyquist mode is treated. + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(1, 10, cutoff=3), + np.array([[True, True, True, True, False, False]]), + ) + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(1, 11, cutoff=3), + np.array([[True, True, True, True, False, False]]), + ) + + +def test_nyquist_filter_masks_1d(): + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(1, 10), + np.array([[True, True, True, True, True, False]]), + ) + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(1, 11), + np.array([[True, True, True, True, True, True]]), + ) + + +def test_low_pass_filter_masks_2d(): + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(2, 10, cutoff=3), + np.array( + [ + [ + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + ] + ] + ), + ) + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(2, 11, cutoff=3), + np.array( + [ + [ + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + ] + ] + ), + ) + # Below is with `axis_separate=False` which not creates `True`-hypercube + # regions but spheres (in 3d) or circles (in 2d) of `True` values. + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(2, 10, cutoff=3, axis_separate=False), + np.array( + [ + [ + [True, True, True, True, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + ] + ] + ), + ) + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(2, 11, cutoff=3, axis_separate=False), + np.array( + [ + [ + [True, True, True, True, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + ] + ] + ), + ) + + +def test_nyquist_filter_masks_2d(): + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(2, 10), + np.array( + [ + [ + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [False, False, False, False, False, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + [True, True, True, True, True, False], + ] + ] + ), + ) + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(2, 11), + np.array( + [ + [ + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + ] + ] + ), + ) + + +def test_low_pass_filter_masks_3d(): + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(3, 8, cutoff=2), + np.array( + [ + [ + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + ] + ] + ), + ) + np.testing.assert_equal( + ex.spectral.low_pass_filter_mask(3, 9, cutoff=2), + np.array( + [ + [ + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + [ + [True, True, True, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, False, False], + [True, True, True, False, False], + ], + ] + ] + ), + ) + + # TODO: Add tests for `axis_separate=False` in 3D. + + +def test_nyquist_filter_masks_3d(): + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(3, 8), + np.array( + [ + [ + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + [ + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + [False, False, False, False, False], + [True, True, True, True, False], + [True, True, True, True, False], + [True, True, True, True, False], + ], + ] + ] + ), + ) + + np.testing.assert_equal( + ex.spectral.oddball_filter_mask(3, 9), + np.ones((1, 9, 9, (9 // 2) + 1), dtype=bool), + ) diff --git a/tests/test_shape_utilties.py b/tests/test_shape_utilties.py new file mode 100644 index 0000000..61b3d6a --- /dev/null +++ b/tests/test_shape_utilties.py @@ -0,0 +1,23 @@ +import exponax as ex + + +def test_space_indices(): + assert ex.spectral.space_indices(1) == (-1,) + assert ex.spectral.space_indices(2) == (-2, -1) + assert ex.spectral.space_indices(3) == (-3, -2, -1) + + +def test_spatial_shape(): + assert ex.spectral.spatial_shape(1, 64) == (64,) + assert ex.spectral.spatial_shape(2, 64) == (64, 64) + assert ex.spectral.spatial_shape(3, 64) == (64, 64, 64) + + +def test_wavenumber_shape(): + assert ex.spectral.wavenumber_shape(1, 64) == (33,) + assert ex.spectral.wavenumber_shape(2, 64) == (64, 33) + assert ex.spectral.wavenumber_shape(3, 64) == (64, 64, 33) + + assert ex.spectral.wavenumber_shape(1, 65) == (33,) + assert ex.spectral.wavenumber_shape(2, 65) == (65, 33) + assert ex.spectral.wavenumber_shape(3, 65) == (65, 65, 33) diff --git a/tests/test_spectral_scaling_arrays.py b/tests/test_spectral_scaling_arrays.py new file mode 100644 index 0000000..099b2e5 --- /dev/null +++ b/tests/test_spectral_scaling_arrays.py @@ -0,0 +1,215 @@ +import jax +import jax.numpy as jnp +import pytest + +import exponax as ex + + +@pytest.mark.parametrize( + "num_spatial_dims,num_points", [(D, N) for D in [1, 2, 3] for N in [10, 11]] +) +def test_building_scaling_array_for_norm_compensation( + num_spatial_dims: int, num_points: int +): + noise = jax.random.normal( + jax.random.PRNGKey(0), (1,) + (num_points,) * num_spatial_dims + ) + + noise_hat_norm_backward = jnp.fft.rfftn( + noise, + axes=ex.spectral.space_indices(num_spatial_dims), + ) + noise_hat_norm_forward = jnp.fft.rfftn( + noise, + axes=ex.spectral.space_indices(num_spatial_dims), + norm="forward", + ) + + scaling_array = ex.spectral.build_scaling_array( + num_spatial_dims, + num_points, + mode="norm_compensation", + ) + + noise_hat_norm_backward_scaled = noise_hat_norm_backward / scaling_array + + assert noise_hat_norm_backward_scaled == pytest.approx(noise_hat_norm_forward) + + +# Mode "reconstruction" is already tested as part of the `test_interpolation.py`` + + +def test_building_scaling_array_for_coef_extraction(): + # 1D + grid_1d = ex.make_grid(1, 2 * jnp.pi, 10) + + u = 3 * jnp.cos(2 * grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.ones_like(grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(2 * grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2] == pytest.approx(0.0 - 3.0j) + + u = 3.0 * jnp.cos(5 * grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 5] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(5 * grid_1d) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 1, + 10, + mode="coef_extraction", + ) + # Nyquist mode sine cannot be captured + assert u_hat_scaled.round(5)[0, 5] == pytest.approx(0.0 + 0.0j) + + # 2D - single terms + grid_2d = ex.make_grid(2, 2 * jnp.pi, 10) + + u = 3.0 * jnp.cos(2 * grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 0] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.cos(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0, 2] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.ones_like(grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0, 0] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(2 * grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 0] == pytest.approx(0.0 - 3.0j) + + u = 3.0 * jnp.sin(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0, 2] == pytest.approx(0.0 - 3.0j) + + u = 3.0 * jnp.cos(5 * grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 5, 0] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.cos(5 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 0, 5] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(5 * grid_2d[0:1]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + # Nyquist mode sine cannot be captured + assert u_hat_scaled.round(5)[0, 5, 0] == pytest.approx(0.0 + 0.0j) + + u = 3.0 * jnp.sin(5 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + # Nyquist mode sine cannot be captured + assert u_hat_scaled.round(5)[0, 0, 5] == pytest.approx(0.0 + 0.0j) + + # 2D - mixed terms + u = 3.0 * jnp.cos(2 * grid_2d[0:1]) * jnp.cos(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 2] == pytest.approx(3.0 + 0.0j) + + u = 3.0 * jnp.sin(2 * grid_2d[0:1]) * jnp.sin(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 2] == pytest.approx(-3.0 + 0.0j) + + u = 3.0 * jnp.cos(2 * grid_2d[0:1]) * jnp.sin(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 2] == pytest.approx(0.0 - 3.0j) + + u = 3.0 * jnp.sin(2 * grid_2d[0:1]) * jnp.cos(2 * grid_2d[1:2]) + u_hat = ex.fft(u) + u_hat_scaled = u_hat / ex.spectral.build_scaling_array( + 2, + 10, + mode="coef_extraction", + ) + assert u_hat_scaled.round(5)[0, 2, 2] == pytest.approx(0.0 - 3.0j) + + # TODO: 3D diff --git a/tests/test_spectrum.py b/tests/test_spectrum.py new file mode 100644 index 0000000..6bac997 --- /dev/null +++ b/tests/test_spectrum.py @@ -0,0 +1,66 @@ +import jax.numpy as jnp +import pytest + +import exponax as ex + + +def test_amplitude_spectrum(): + # 1D + grid_1d = ex.make_grid(1, 2 * jnp.pi, 128) + + u = 3.0 * jnp.sin(grid_1d) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 1] == pytest.approx(3.0) + + u = 3.0 * jnp.cos(2 * grid_1d) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 2] == pytest.approx(3.0) + + u = 3.0 * jnp.sin(3 * grid_1d) + 4.0 * jnp.cos(3 * grid_1d) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 3] == pytest.approx(jnp.sqrt(3.0**2 + 4.0**2)) + + u = 3.0 * jnp.ones_like(grid_1d) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 0] == pytest.approx(3.0) + + # 2D - single terms + grid_2d = ex.make_grid(2, 2 * jnp.pi, 48) + + u = 3.0 * jnp.sin(grid_2d[0:1]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 1] == pytest.approx(3.0) + + u = 3.0 * jnp.cos(2 * grid_2d[0:1]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 2] == pytest.approx(3.0) + + u = 3.0 * jnp.sin(3 * grid_2d[0:1]) + 4.0 * jnp.cos(3 * grid_2d[0:1]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 3] == pytest.approx(jnp.sqrt(3.0**2 + 4.0**2)) + + u = 3.0 * jnp.ones_like(grid_2d[0:1]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 0] == pytest.approx(3.0) + + u = 3.0 * jnp.sin(grid_2d[1:2]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 1] == pytest.approx(3.0) + + u = 3.0 * jnp.cos(2 * grid_2d[1:2]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 2] == pytest.approx(3.0) + + # 2D - mixed terms + u = 3.0 * jnp.sin(1 * grid_2d[0:1]) * jnp.cos(1 * grid_2d[1:2]) + spectrum = ex.get_spectrum(u, power=False) + assert spectrum[0, 1] == pytest.approx(3.0) + + u = 3.0 * jnp.sin(2 * grid_2d[0:1]) * jnp.cos(2 * grid_2d[1:2]) + spectrum = ex.get_spectrum(u, power=False) + # The amplitude is in the 3-bin because the wavenumber norm of [2, 2] is + # 2*sqrt(2) = 2.8284 which is not in the interval [1.5, 2.5). + assert spectrum[0, 3] == pytest.approx(3.0) + assert spectrum[0, 2] == pytest.approx(0.0, abs=1e-5) + + # TODO: 3D