diff --git a/docs/configuration.rst b/docs/configuration.rst index daf334a1..433d4f58 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -51,7 +51,7 @@ The following inputs are valid for **time-varying-scalar** parameters: * Single integer, float, or boolean. The parameter is then not time dependent * A time-series dict with ``{time: value}`` pairs, using the default ``interpolation_mode='PIECEWISE_LINEAR'``. -* A tuple with ``(dict, str)`` corresponding to ``(time-series, interpolation_mode)``. +* A tuple with ``(time-series, value-series)``. The time-series is a 1D array of times, and the value-series is a 1D array of values and the dimensions of both must match. * A ``xarray.DataArray`` with a single coordinate and a 1D value array. Examples: @@ -81,7 +81,8 @@ To extend configuration parameters where time-dependence is not enabled, to have Time-varying arrays ------------------- -Time-varying arrays can be defined using either primitives or an ``xarray.DataArray``. +Time-varying arrays can be defined using either primitives, an +``xarray.DataArray`` or a ``dict`` of ``Array``. Using primitives ^^^^^^^^^^^^^^^^ @@ -138,6 +139,14 @@ If a ``xarray.DataArray`` is specified then it is expected to have a ``time`` and ``rho_norm`` coordinate. The values of the data array are the values at each time and rho_norm. +Using ``dict`` of ``Array`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +If a ``dict`` of ``Array`` is used, the dict must have keys of ``time``, +``rho_norm`` and ``value``. + +``time`` and ``rho_norm`` are expected to map to 1D array values. +``value`` is expected to map to a 2D array with shape ``(len(time), len(rho_norm))``. + .. _config_details: Detailed configuration structure diff --git a/torax/interpolated_param.py b/torax/interpolated_param.py index a730f284..f44904b1 100644 --- a/torax/interpolated_param.py +++ b/torax/interpolated_param.py @@ -17,6 +17,7 @@ import abc from collections.abc import Mapping import enum + import chex import jax import jax.numpy as jnp @@ -160,6 +161,7 @@ def get_value( Mapping[float, InterpolatedVarSingleAxisInput] | float | xr.DataArray + | dict[str, chex.Array] ) @@ -301,7 +303,8 @@ def is_bool_param(self) -> bool: class InterpolatedVarTimeRho(InterpolatedParamBase): """Interpolates on a grid (time, rho). - This class is initialised with `values`, either primitives or an xr.DataArray. + This class is initialised with `values`, either primitives, an xr.DataArray or + a a dict of `Array`s. If primitives are used, then `values` is expected as a mapping from time-values to `InterpolatedVarSingleAxis`s that tell you how to interpolate @@ -316,6 +319,11 @@ class InterpolatedVarTimeRho(InterpolatedParamBase): If you only want to use a subset of the xr.DataArray, filter the data array beforehand, e.g. `values=array.sel(time=[0.0, 2.0])` + If a dict of `Array`s is used, The dict is expected to have + ['time', 'rho_norm', 'value'] keys. The `time` and `rho_norm` are expected to + be 1D arrays and `value` is expected to be a 2D array with shape (len(time), + len(rho)). + This class linearly interpolates along time to provide a value at any (time, rho) pair. For time values that are outside the range of `values` the closest defined `InterpolatedVarSingleAxis` is used. @@ -324,6 +332,21 @@ class InterpolatedVarTimeRho(InterpolatedParamBase): at init and take just time at get_value. """ + def _load_from_arrays( + self, + arrays: dict[str, chex.Array], + rho_interpolation_mode: InterpolationMode, + ): + """Loads the data from numpy arrays.""" + self.times_values = { + t: InterpolatedVarSingleAxis( + (arrays['rho_norm'], arrays['value'][i, :]), + rho_interpolation_mode, + ) + for i, t in enumerate(arrays['time']) + } + self.sorted_indices = jnp.array(sorted(arrays['time'])) + def _load_from_xr_array( self, array: xr.DataArray, @@ -336,10 +359,7 @@ def _load_from_xr_array( raise ValueError('"rho_norm" must be a coordinate in given dataset.') self.times_values = { t: InterpolatedVarSingleAxis( - ( - array.rho_norm.data, - array.sel(time=t).values, - ), + (array.rho_norm.data, array.sel(time=t).values,), rho_interpolation_mode, ) for t in array.time.data @@ -386,6 +406,10 @@ def __init__( self._rho = rho if isinstance(values, xr.DataArray): self._load_from_xr_array(values, rho_interpolation_mode) + elif isinstance(values, Mapping) and all( + isinstance(v, chex.Array) for v in values.values() + ): + self._load_from_arrays(values, rho_interpolation_mode) else: self._load_from_primitives(values, rho_interpolation_mode) diff --git a/torax/tests/interpolated_param.py b/torax/tests/interpolated_param.py index a6d12778..04013dfa 100644 --- a/torax/tests/interpolated_param.py +++ b/torax/tests/interpolated_param.py @@ -429,7 +429,39 @@ def test_interpolated_var_time_rho_parses_xr_array_input(self): coords={'time': [0.0, 1.0], 'rho_norm': [0.25, 0.5, 0.75]}, ) interpolated_var_time_rho = interpolated_param.InterpolatedVarTimeRho( - values=array, rho=np.array([0.25, 0.5, 0.75]), + values=array, + rho=np.array([0.25, 0.5, 0.75]), + ) + + np.testing.assert_allclose( + interpolated_var_time_rho.get_value( + x=0.0, + ), + np.array([1.0, 2.0, 3.0]), + ) + np.testing.assert_allclose( + interpolated_var_time_rho.get_value( + x=1.0, + ), + np.array([4.0, 5.0, 6.0]), + ) + np.testing.assert_allclose( + interpolated_var_time_rho.get_value( + x=0.5, + ), + np.array([2.5, 3.5, 4.5]), + ) + + def test_interpolated_var_time_rho_parses_dict_array_input(self): + """Tests that InterpolatedVarTimeRho parses TimeRhoValueArray inputs correctly.""" + arrays = dict( + time=np.array([0.0, 1.0]), + rho_norm=np.array([0.25, 0.5, 0.75]), + value=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + ) + interpolated_var_time_rho = interpolated_param.InterpolatedVarTimeRho( + values=arrays, + rho=np.array([0.25, 0.5, 0.75]), ) np.testing.assert_allclose(