Skip to content

Commit

Permalink
Add NumPy/array loading support for interpolated params.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655107377
  • Loading branch information
Nush395 authored and Torax team committed Jul 29, 2024
1 parent 097a493 commit aa4cb31
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 8 deletions.
13 changes: 11 additions & 2 deletions docs/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ The following inputs are valid for **time-varying-scalar** parameters:

* Single integer, float, or boolean. The parameter is then not time dependent
* A time-series dict with ``{time: value}`` pairs, using the default ``interpolation_mode='PIECEWISE_LINEAR'``.
* A tuple with ``(dict, str)`` corresponding to ``(time-series, interpolation_mode)``.
* A tuple with ``(time-series, value-series)``. The time-series is a 1D array of times, and the value-series is a 1D array of values and the dimensions of both must match.
* A ``xarray.DataArray`` with a single coordinate and a 1D value array.

Examples:
Expand Down Expand Up @@ -81,7 +81,8 @@ To extend configuration parameters where time-dependence is not enabled, to have

Time-varying arrays
-------------------
Time-varying arrays can be defined using either primitives or an ``xarray.DataArray``.
Time-varying arrays can be defined using either primitives, an
``xarray.DataArray`` or a ``dict`` of ``Array``.

Using primitives
^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -138,6 +139,14 @@ If a ``xarray.DataArray`` is specified then it is expected to have a
``time`` and ``rho_norm`` coordinate. The values of the data array are the values
at each time and rho_norm.

Using ``dict`` of ``Array``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If a ``dict`` of ``Array`` is used, the dict must have keys of ``time``,
``rho_norm`` and ``value``.

``time`` and ``rho_norm`` are expected to map to 1D array values.
``value`` is expected to map to a 2D array with shape ``(len(time), len(rho_norm))``.

.. _config_details:

Detailed configuration structure
Expand Down
34 changes: 29 additions & 5 deletions torax/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
from collections.abc import Mapping
import enum

import chex
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -160,6 +161,7 @@ def get_value(
Mapping[float, InterpolatedVarSingleAxisInput]
| float
| xr.DataArray
| dict[str, chex.Array]
)


Expand Down Expand Up @@ -301,7 +303,8 @@ def is_bool_param(self) -> bool:
class InterpolatedVarTimeRho(InterpolatedParamBase):
"""Interpolates on a grid (time, rho).
This class is initialised with `values`, either primitives or an xr.DataArray.
This class is initialised with `values`, either primitives, an xr.DataArray or
a a dict of `Array`s.
If primitives are used, then `values` is expected as a mapping from
time-values to `InterpolatedVarSingleAxis`s that tell you how to interpolate
Expand All @@ -316,6 +319,11 @@ class InterpolatedVarTimeRho(InterpolatedParamBase):
If you only want to use a subset of the xr.DataArray, filter the data array
beforehand, e.g. `values=array.sel(time=[0.0, 2.0])`
If a dict of `Array`s is used, The dict is expected to have
['time', 'rho_norm', 'value'] keys. The `time` and `rho_norm` are expected to
be 1D arrays and `value` is expected to be a 2D array with shape (len(time),
len(rho)).
This class linearly interpolates along time to provide a value at any
(time, rho) pair. For time values that are outside the range of `values` the
closest defined `InterpolatedVarSingleAxis` is used.
Expand All @@ -324,6 +332,21 @@ class InterpolatedVarTimeRho(InterpolatedParamBase):
at init and take just time at get_value.
"""

def _load_from_arrays(
self,
arrays: dict[str, chex.Array],
rho_interpolation_mode: InterpolationMode,
):
"""Loads the data from numpy arrays."""
self.times_values = {
t: InterpolatedVarSingleAxis(
(arrays['rho_norm'], arrays['value'][i, :]),
rho_interpolation_mode,
)
for i, t in enumerate(arrays['time'])
}
self.sorted_indices = jnp.array(sorted(arrays['time']))

def _load_from_xr_array(
self,
array: xr.DataArray,
Expand All @@ -336,10 +359,7 @@ def _load_from_xr_array(
raise ValueError('"rho_norm" must be a coordinate in given dataset.')
self.times_values = {
t: InterpolatedVarSingleAxis(
(
array.rho_norm.data,
array.sel(time=t).values,
),
(array.rho_norm.data, array.sel(time=t).values,),
rho_interpolation_mode,
)
for t in array.time.data
Expand Down Expand Up @@ -386,6 +406,10 @@ def __init__(
self._rho = rho
if isinstance(values, xr.DataArray):
self._load_from_xr_array(values, rho_interpolation_mode)
elif isinstance(values, Mapping) and all(
isinstance(v, chex.Array) for v in values.values()
):
self._load_from_arrays(values, rho_interpolation_mode)
else:
self._load_from_primitives(values, rho_interpolation_mode)

Expand Down
34 changes: 33 additions & 1 deletion torax/tests/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,39 @@ def test_interpolated_var_time_rho_parses_xr_array_input(self):
coords={'time': [0.0, 1.0], 'rho_norm': [0.25, 0.5, 0.75]},
)
interpolated_var_time_rho = interpolated_param.InterpolatedVarTimeRho(
values=array, rho=np.array([0.25, 0.5, 0.75]),
values=array,
rho=np.array([0.25, 0.5, 0.75]),
)

np.testing.assert_allclose(
interpolated_var_time_rho.get_value(
x=0.0,
),
np.array([1.0, 2.0, 3.0]),
)
np.testing.assert_allclose(
interpolated_var_time_rho.get_value(
x=1.0,
),
np.array([4.0, 5.0, 6.0]),
)
np.testing.assert_allclose(
interpolated_var_time_rho.get_value(
x=0.5,
),
np.array([2.5, 3.5, 4.5]),
)

def test_interpolated_var_time_rho_parses_dict_array_input(self):
"""Tests that InterpolatedVarTimeRho parses TimeRhoValueArray inputs correctly."""
arrays = dict(
time=np.array([0.0, 1.0]),
rho_norm=np.array([0.25, 0.5, 0.75]),
value=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
)
interpolated_var_time_rho = interpolated_param.InterpolatedVarTimeRho(
values=arrays,
rho=np.array([0.25, 0.5, 0.75]),
)

np.testing.assert_allclose(
Expand Down

0 comments on commit aa4cb31

Please sign in to comment.