Skip to content

Commit

Permalink
Wrap FFT for specific array structures in Exponax (#4)
Browse files Browse the repository at this point in the history
* Add custom wrapped FFT calls

* Rewrite derivative routines to use wrapped FFT calls

* Make number of spatial dims optional

* Run black

* Correctly declare as optional

* Use wrapped routines

* Use wrapped routines

* Use wrapped routines

* Use wrapped routines

* Use wrapped routines

* Use wrapped routines

* Export the spectral submodule

* Add more details on usage

* Start documentation of spectral module
  • Loading branch information
Ceyron authored Sep 2, 2024
1 parent b253ea6 commit 545fd3f
Show file tree
Hide file tree
Showing 10 changed files with 239 additions and 142 deletions.
7 changes: 7 additions & 0 deletions docs/api/utilities/spectral.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Fourier-spectral utilities

::: exponax.fft

---

::: exponax.ifft
6 changes: 5 additions & 1 deletion exponax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from . import _metrics as metrics
from . import _poisson as poisson
from . import _spectral as spectral
from . import etdrk, ic, nonlin_fun, normalized, reaction, stepper, viz
from ._base_stepper import BaseStepper
from ._forced_stepper import ForcedStepper
from ._repeated_stepper import RepeatedStepper
from ._spectral import derivative, make_incompressible
from ._spectral import derivative, fft, ifft, make_incompressible
from ._utils import (
build_ic_set,
make_grid,
Expand All @@ -22,6 +23,8 @@
"poisson",
"RepeatedStepper",
"derivative",
"fft",
"ifft",
"make_incompressible",
"make_grid",
"rollout",
Expand All @@ -37,4 +40,5 @@
"reaction",
"stepper",
"viz",
"spectral",
]
12 changes: 6 additions & 6 deletions exponax/_base_stepper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from abc import ABC, abstractmethod

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float

from ._spectral import (
build_derivative_operator,
space_indices,
fft,
ifft,
spatial_shape,
wavenumber_shape,
)
Expand Down Expand Up @@ -202,12 +202,12 @@ def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]:
**Returns:**
- `u_next`: The state vector after one step, shape `(C, ..., N,)`.
"""
u_hat = jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims))
u_hat = fft(u, num_spatial_dims=self.num_spatial_dims)
u_next_hat = self.step_fourier(u_hat)
u_next = jnp.fft.irfftn(
u_next = ifft(
u_next_hat,
s=spatial_shape(self.num_spatial_dims, self.num_points),
axes=space_indices(self.num_spatial_dims),
num_spatial_dims=self.num_spatial_dims,
num_points=self.num_points,
)
return u_next

Expand Down
6 changes: 3 additions & 3 deletions exponax/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
from jaxtyping import Array, Float

from ._spectral import low_pass_filter_mask, space_indices
from ._spectral import fft, low_pass_filter_mask


def _MSE(
Expand Down Expand Up @@ -505,8 +505,8 @@ def _fourier_nRMSE(

mask = jnp.invert(low_mask) & high_mask

u_pred_fft = jnp.fft.rfftn(u_pred, axes=space_indices(num_spatial_dims))
u_ref_fft = jnp.fft.rfftn(u_ref, axes=space_indices(num_spatial_dims))
u_pred_fft = fft(u_pred, num_spatial_dims=num_spatial_dims)
u_ref_fft = fft(u_ref, num_spatial_dims=num_spatial_dims)

# The FFT incurse rounding errors around the machine precision that can be
# noticeable in the nRMSE. We will zero out the values that are smaller than
Expand Down
11 changes: 6 additions & 5 deletions exponax/_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from ._spectral import (
build_derivative_operator,
build_laplace_operator,
space_indices,
fft,
ifft,
spatial_shape,
)

Expand Down Expand Up @@ -90,12 +91,12 @@ def step(
**Returns:**
- `u`: The solution.
"""
f_hat = jnp.fft.rfftn(f, axes=space_indices(self.num_spatial_dims))
f_hat = fft(f, num_spatial_dims=self.num_spatial_dims)
u_hat = self.step_fourier(f_hat)
u = jnp.fft.irfftn(
u = ifft(
u_hat,
axes=space_indices(self.num_spatial_dims),
s=spatial_shape(self.num_spatial_dims, self.num_points),
num_spatial_dims=self.num_spatial_dims,
num_points=self.num_points,
)
return u

Expand Down
Loading

0 comments on commit 545fd3f

Please sign in to comment.