Skip to content

Commit

Permalink
Spectral updates (#35)
Browse files Browse the repository at this point in the history
* Preliminary function to compute the spectrum

* Potential improvement to spectrum computation

* Add hint on enstrophy

* Fix on wavenumber norm computation

* Rename to better reflect its effect on oddball mode

* Add tests on filter mask

* Improve first half of spectral documentation

* Merge common subexpressions

* Extend and fix documentation

* Add experimental convinience function to extract the Fourier coefficients

* Unify the creation of scaling arrays

* Test scaling array for norm_compensation

* Add tests for coefficient extraction in 1D

* Test for coefficient extration in 2D

* Enhance docstring

* Remove experimental function

* Fix spectrum function

* Export get_spectrum instead of make_incompressible

* Adapt docs

* Use proper links

* Enhance docstring

* Enhance docstring

* Test the spectrum creation

* Restructure docstring

* Tests helper utilities for fft setup
  • Loading branch information
Ceyron authored Sep 5, 2024
1 parent 7ade794 commit 379e794
Show file tree
Hide file tree
Showing 12 changed files with 1,257 additions and 201 deletions.
2 changes: 1 addition & 1 deletion docs/api/utilities/derivatives.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

---

::: exponax.make_incompressible
::: exponax.spectral.make_incompressible
10 changes: 9 additions & 1 deletion docs/api/utilities/spectral.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,12 @@

---

::: exponax.ifft
::: exponax.ifft

---

::: exponax.get_spectrum

---

::: exponax.spectral.build_scaling_array
4 changes: 2 additions & 2 deletions exponax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ._forced_stepper import ForcedStepper
from ._interpolation import FourierInterpolator, map_between_resolutions
from ._repeated_stepper import RepeatedStepper
from ._spectral import derivative, fft, ifft, make_incompressible
from ._spectral import derivative, fft, get_spectrum, ifft
from ._utils import (
build_ic_set,
make_grid,
Expand All @@ -26,7 +26,7 @@
"derivative",
"fft",
"ifft",
"make_incompressible",
"get_spectrum",
"make_grid",
"rollout",
"repeat",
Expand Down
16 changes: 10 additions & 6 deletions exponax/_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
from jaxtyping import Array, Complex, Float

from ._spectral import (
build_reconstructional_scaling_array,
build_scaled_wavenumbers,
build_scaling_array,
fft,
get_modes_slices,
ifft,
nyquist_filter_mask,
oddball_filter_mask,
space_indices,
wavenumber_shape,
)
Expand Down Expand Up @@ -84,8 +83,11 @@ def __init__(
self.num_points = state.shape[-1]

self.state_hat_scaled = fft(state, num_spatial_dims=self.num_spatial_dims) / (
build_reconstructional_scaling_array(
self.num_spatial_dims, self.num_points, indexing=indexing
build_scaling_array(
self.num_spatial_dims,
self.num_points,
mode="reconstruction",
indexing=indexing,
)
)
self.wavenumbers = build_scaled_wavenumbers(
Expand Down Expand Up @@ -242,12 +244,13 @@ def map_between_resolutions(
) / build_scaling_array(
num_spatial_dims,
old_num_points,
mode="norm_compensation",
)

if new_num_points > old_num_points:
# Upscaling
if old_num_points % 2 == 0 and oddball_zero:
old_state_hat_scaled *= nyquist_filter_mask(
old_state_hat_scaled *= oddball_filter_mask(
num_spatial_dims, old_num_points
)

Expand All @@ -269,11 +272,12 @@ def map_between_resolutions(
new_state_hat = new_state_hat_scaled * build_scaling_array(
num_spatial_dims,
new_num_points,
mode="norm_compensation",
)
if old_num_points > new_num_points:
# Downscaling
if new_num_points % 2 == 0 and oddball_zero:
new_state_hat *= nyquist_filter_mask(num_spatial_dims, new_num_points)
new_state_hat *= oddball_filter_mask(num_spatial_dims, new_num_points)

new_state = ifft(
new_state_hat,
Expand Down
Loading

0 comments on commit 379e794

Please sign in to comment.