Skip to content

Commit

Permalink
Generalize cumulative_trapezoid to fully support the scipy.integrate.…
Browse files Browse the repository at this point in the history
…cumulative_trapezoid API.

This also adds tests for previously untested codepaths, eg. when `x=None`.

Also fixes argument order: previous version used x before y, which is different from scipy. This could easily confuse users of this function.

PiperOrigin-RevId: 656395970
  • Loading branch information
sbodenstein authored and Torax team committed Jul 29, 2024
1 parent 097a493 commit ff3446f
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 39 deletions.
8 changes: 2 additions & 6 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,7 @@ def _update_psi_from_j(
y = currents.jtot_hires * geo.vpr_hires
assert y.ndim == 1
assert geo.r_hires.ndim == 1
integrated = math_utils.cumulative_trapezoid(
geo.r_hires, y, initial=jnp.zeros(())
)
integrated = math_utils.cumulative_trapezoid(y=y, x=geo.r_hires, initial=0.0)
scale = jnp.concatenate((
jnp.zeros((1,)),
(8 * jnp.pi**3 * constants.CONSTANTS.mu0 * geo.B0)
Expand All @@ -516,9 +514,7 @@ def _update_psi_from_j(

# psi on cell grid
psi_hires = math_utils.cumulative_trapezoid(
geo.r_hires,
dpsi_dr_hires,
initial=jnp.zeros(()),
y=dpsi_dr_hires, x=geo.r_hires, initial=0.0
)

psi_value = jnp.interp(geo.r, geo.r_hires, psi_hires)
Expand Down
68 changes: 53 additions & 15 deletions torax/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Math operations that are needed for Torax, but are not specific to plasma
physics or differential equation solvers.
"""
from typing import Optional
import functools
import jax
from jax import numpy as jnp

Expand All @@ -40,30 +40,68 @@ def tridiag(


def cumulative_trapezoid(
x: jax.Array, y: jax.Array, initial: Optional[jax.Array] = None
y: jax.Array,
x: jax.Array | None = None,
dx: float = 1.0,
axis: int = -1,
initial: float | None = None,
) -> jax.Array:
"""Cumulatively integrate y = f(x) using the trapezoid rule.
Jax equivalent of scipy.integrate.cumulative_trapezoid.
without as much support for different shapes / options as the scipy version.
JAX equivalent of scipy.integrate.cumulative_trapezoid.
Args:
x: 1-D array
y: 1-D array
initial: Optional array containing a single value. If specified, out[i] =
trapz(y[:i +1], x[:i + 1]), with out[0] = initial. Usually initial should
be 0 in this case. If left unspecified, the leftmost output, corresponding
to summing no terms, is omitted.
y: array of data to integrate.
x: optional array of sample points corresponding to the `y` values. If not
provided, `x` defaults to equally spaced with spacing given by `dx`.
dx: the spacing between sample points when `x` is None (default: 1.0).
axis: the axis along which to integrate (default: -1)
initial: a scalar value to prepend to the result. Either None (default) or
0.0. If `initial=0`, the result is an array with the same shape as `y`. If
``initial=None``, the resulting array has one fewer elements than `y`
along the `axis` dimension.
Returns:
out: 1-D array of same shape, containing the cumulative integration by
trapezoid rule.
The cumulative definite integral approximated by the trapezoidal rule.
"""

d = jnp.diff(x)
out = jnp.cumsum(d * (y[1:] + y[:-1])) / 2.0
if x is None:
dx = jnp.asarray(dx, dtype=y.dtype)
else:
if x.ndim == 1:
if y.shape[axis] != len(x):
raise ValueError(
f'The length of x is {len(x)}, but expected {y.shape[axis]}.'
)
else:
if x.shape != y.shape:
raise ValueError(
'If x is not 1 dimensional, it must have the same shape as y.'
)

if x.ndim == 1:
dx = jnp.diff(x)
new_shape = [1] * y.ndim
new_shape[axis] = len(dx)
dx = jnp.reshape(dx, new_shape)
else:
dx = jnp.diff(x, axis=axis)

y_sliced = functools.partial(jax.lax.slice_in_dim, y, axis=axis)

out = jnp.cumsum(dx * (y_sliced(1, None) + y_sliced(0, -1)), axis=axis) / 2.0

if initial is not None:
out = jnp.concatenate((jnp.expand_dims(initial, 0), out))
if initial != 0.0:
raise ValueError(
'`initial` must be 0 or None. Non-zero values have been deprecated'
' since SciPy version 1.12.0.'
)
initial_array = jnp.asarray(initial, dtype=out.dtype)
initial_shape = list(out.shape)
initial_shape[axis] = 1
initial_array = jnp.broadcast_to(initial_array, initial_shape)
out = jnp.concatenate((initial_array, out), axis=axis)
return out


Expand Down
2 changes: 1 addition & 1 deletion torax/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def calc_s_from_psi(

# Volume on face grid
# pylint:disable=invalid-name
V = math_utils.cumulative_trapezoid(geo.r_face, geo.vpr_face)
V = math_utils.cumulative_trapezoid(y=geo.vpr_face, x=geo.r_face)

s = (
2
Expand Down
49 changes: 32 additions & 17 deletions torax/tests/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,53 @@
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
import scipy.integrate
from torax import math_utils

jax.config.update('jax_enable_x64', True)


class MathUtilsTest(parameterized.TestCase):
"""Unit tests for the `torax.math_utils` module."""

@parameterized.parameters([
dict(seed=20221007, initial=None),
dict(seed=20221007, initial=0.0),
dict(seed=20221007, initial=1.0),
])
def test_cumulative_trapz(self, seed, initial):
@parameterized.product(
initial=(None, 0.),
axis=(-1, 1, -1),
array_x=(False, True),
dtype=(jnp.float32, jnp.float64),
shape=((13,), (2, 4, 1, 3)),
)
def test_cumulative_trapezoid(self, axis, array_x, initial, dtype, shape):
"""Test that cumulative_trapezoid matches the scipy implementation."""
rng_state = jax.random.PRNGKey(seed)
del seed # Make sure seed isn't accidentally re-used
rng_state = jax.random.PRNGKey(20221007)
rng_use_y, rng_use_x, _ = jax.random.split(rng_state, 3)

rng_use_dim, rng_use_y, rng_use_x, _ = jax.random.split(
rng_state, 4
)
dim = int(jax.random.randint(rng_use_dim, (1,), 1, 100)[0])
y = jax.random.normal(rng_use_y, (dim,))
if axis == 1 and len(shape) == 1:
self.skipTest('Axis out of range.')

dx = 0.754
y = jax.random.normal(rng_use_y, shape=shape, dtype=dtype)
del rng_use_y # Make sure rng_use_y isn't accidentally re-used
x = jax.random.normal(rng_use_x, (dim,))
if array_x:
x = jax.random.normal(rng_use_x, (shape[axis],), dtype=dtype)
else:
x = None
del rng_use_x # Make sure rng_use_x isn't accidentally re-used

cumulative = math_utils.cumulative_trapezoid(x, y, initial=initial)
cumulative = math_utils.cumulative_trapezoid(
y, x, dx=dx, axis=axis, initial=initial
)

self.assertEqual(cumulative.dtype, y.dtype)

ref = scipy.integrate.cumulative_trapezoid(y, x, initial=initial)
ref = scipy.integrate.cumulative_trapezoid(
y, x, dx=dx, axis=axis, initial=initial
)

np.testing.assert_allclose(cumulative, ref)
atol = 3e-7 if dtype == jnp.float32 else 1e-12
np.testing.assert_allclose(cumulative, ref, atol=atol)


if __name__ == '__main__':
Expand Down

0 comments on commit ff3446f

Please sign in to comment.