Skip to content

Commit

Permalink
Move Grid1D to interpolated_param_2d. This allows Grid1D to be a fiel…
Browse files Browse the repository at this point in the history
…d of TimeVaryingArray.

A followup CL will finish the refactor where face_centers and cell_centers are properties of the model, and the model is purely defined by nx and dx.

PiperOrigin-RevId: 735714390
  • Loading branch information
sbodenstein authored and Torax team committed Mar 11, 2025
1 parent c53225b commit 0b1bbd2
Show file tree
Hide file tree
Showing 41 changed files with 253 additions and 245 deletions.
3 changes: 2 additions & 1 deletion docs/model_integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ parameters format and what is supported by TORAX.)
from torax import state
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.torax_pydantic import torax_pydantic
from torax.sources import runtime_params as runtime_params_lib
# This inherits from the default source runtime parameters.
Expand All @@ -86,7 +87,7 @@ parameters format and what is supported by TORAX.)
def make_provider(
self,
torax_mesh: geometry.Grid1D | None = None,
torax_mesh: torax_pydantic.Grid1D | None = None,
) -> RuntimeParamsProvider:
return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh))
Expand Down
8 changes: 4 additions & 4 deletions torax/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import chex
from torax import interpolated_param
from torax.config import config_args
from torax.geometry import geometry
from torax.torax_pydantic import torax_pydantic

DynamicT = TypeVar('DynamicT')
ProviderT = TypeVar('ProviderT', bound='RuntimeParametersProvider')
Expand All @@ -36,7 +36,7 @@ class GridType(enum.Enum):
CELL = enum.auto()
FACE = enum.auto()

def get_mesh(self, torax_mesh: geometry.Grid1D) -> chex.Array:
def get_mesh(self, torax_mesh: torax_pydantic.Grid1D) -> chex.Array:
match self:
case GridType.CELL:
return torax_mesh.cell_centers
Expand Down Expand Up @@ -64,7 +64,7 @@ def grid_type(self) -> GridType:
return GridType.CELL

def get_provider_kwargs(
self, torax_mesh: geometry.Grid1D | None = None
self, torax_mesh: torax_pydantic.Grid1D | None = None
) -> dict[str, Any]:
"""Returns the kwargs to be passed to the provider constructor.
Expand Down Expand Up @@ -127,7 +127,7 @@ def get_provider_kwargs(
@abc.abstractmethod
def make_provider(
self,
torax_mesh: geometry.Grid1D | None = None,
torax_mesh: torax_pydantic.Grid1D | None = None,
) -> ProviderT:
"""Builds a RuntimeParamsProvider object from this config.
Expand Down
5 changes: 3 additions & 2 deletions torax/config/build_runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
from torax.sources import runtime_params as sources_params
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.torax_pydantic import torax_pydantic
from torax.transport_model import runtime_params as transport_model_params


def build_static_runtime_params_slice(
*,
runtime_params: general_runtime_params_lib.GeneralRuntimeParams,
source_runtime_params: dict[str, sources_params.RuntimeParams],
torax_mesh: geometry.Grid1D,
torax_mesh: torax_pydantic.Grid1D,
stepper: stepper_pydantic_model.Stepper | None = None,
) -> runtime_params_slice.StaticRuntimeParamsSlice:
"""Builds a StaticRuntimeParamsSlice.
Expand Down Expand Up @@ -144,7 +145,7 @@ def __init__(
transport: transport_model_params.RuntimeParams | None = None,
sources: dict[str, sources_params.RuntimeParams] | None = None,
stepper: stepper_pydantic_model.Stepper | None = None,
torax_mesh: geometry.Grid1D | None = None,
torax_mesh: torax_pydantic.Grid1D | None = None,
):
"""Constructs a build_simulation_params.DynamicRuntimeParamsSliceProvider.
Expand Down
3 changes: 1 addition & 2 deletions torax/config/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from torax import array_typing
from torax import interpolated_param
from torax.config import base
from torax.geometry import geometry
from torax.torax_pydantic import torax_pydantic
from typing_extensions import override
from typing_extensions import Self
Expand Down Expand Up @@ -166,7 +165,7 @@ class Numerics(base.RuntimeParametersConfig):

@override
def make_provider(
self, torax_mesh: geometry.Grid1D | None = None
self, torax_mesh: torax_pydantic.Grid1D | None = None
) -> NumericsProvider:
return NumericsProvider(**self.get_provider_kwargs(torax_mesh))

Expand Down
3 changes: 1 addition & 2 deletions torax/config/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from torax import interpolated_param
from torax.config import base
from torax.config import config_args
from torax.geometry import geometry
from torax.torax_pydantic import torax_pydantic
from typing_extensions import Self

Expand Down Expand Up @@ -322,7 +321,7 @@ class PlasmaComposition(

def make_provider(
self,
torax_mesh: geometry.Grid1D | None = None,
torax_mesh: torax_pydantic.Grid1D | None = None,
) -> PlasmaCompositionProvider:
if torax_mesh is None:
raise ValueError(
Expand Down
3 changes: 1 addition & 2 deletions torax/config/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from torax import interpolated_param
from torax.config import base
from torax.config import config_args
from torax.geometry import geometry
from torax.torax_pydantic import torax_pydantic
from typing_extensions import override
from typing_extensions import Self
Expand Down Expand Up @@ -263,7 +262,7 @@ def __post_init__(self):
@override
def make_provider(
self,
torax_mesh: geometry.Grid1D | None = None,
torax_mesh: torax_pydantic.Grid1D | None = None,
) -> ProfileConditionsProvider:
provider_kwargs = self.get_provider_kwargs(torax_mesh)
if torax_mesh is None:
Expand Down
3 changes: 1 addition & 2 deletions torax/config/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from torax.config import numerics as numerics_lib
from torax.config import plasma_composition as plasma_composition_lib
from torax.config import profile_conditions as profile_conditions_lib
from torax.geometry import geometry
from torax.torax_pydantic import torax_pydantic
from typing_extensions import override

Expand Down Expand Up @@ -88,7 +87,7 @@ class GeneralRuntimeParams(base.RuntimeParametersConfig):
output_dir: str | None = None

def make_provider(
self, torax_mesh: geometry.Grid1D | None = None
self, torax_mesh: torax_pydantic.Grid1D | None = None
) -> GeneralRuntimeParamsProvider:
return GeneralRuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh))

Expand Down
4 changes: 3 additions & 1 deletion torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
from torax.pedestal_model import runtime_params as pedestal_model_params
from torax.sources import runtime_params as sources_params
from torax.stepper import runtime_params as stepper_params
from torax.torax_pydantic import torax_pydantic
from torax.transport_model import runtime_params as transport_model_params

# Many of the variables follow scientific or mathematical notation, so disable
# pylint complaints.
# pylint: disable=invalid-name
Expand Down Expand Up @@ -114,7 +116,7 @@ class StaticRuntimeParamsSlice:
# Mapping of source name to source-specific static runtime params.
sources: Mapping[str, sources_params.StaticRuntimeParams]
# Torax mesh used to construct the geometry.
torax_mesh: geometry.Grid1D
torax_mesh: torax_pydantic.Grid1D
# Solve the ion heat equation (ion temperature evolves over time)
ion_heat_eq: bool
# Solve the electron heat equation (electron temperature evolves over time)
Expand Down
3 changes: 2 additions & 1 deletion torax/geometry/circular_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np
from torax.geometry import geometry
from torax.torax_pydantic import torax_pydantic


# Using invalid-name because we are using the same naming convention as the
Expand Down Expand Up @@ -50,7 +51,7 @@ def build_circular_geometry(
# toroidal flux coordinate.
drho_norm = 1.0 / n_rho
# Define mesh (Slab Uniform 1D with Jacobian = 1)
mesh = geometry.Grid1D.construct(nx=n_rho, dx=drho_norm)
mesh = torax_pydantic.Grid1D.construct(nx=n_rho, dx=drho_norm)
# toroidal flux coordinate (rho) at boundary (last closed flux surface)
rho_b = np.asarray(Rmin)

Expand Down
103 changes: 25 additions & 78 deletions torax/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,61 +25,9 @@
import jax
import jax.numpy as jnp
import numpy as np
import pydantic
from torax.torax_pydantic import torax_pydantic


class Grid1D(torax_pydantic.BaseModelFrozen):
"""Data structure defining a 1-D grid of cells with faces.
Construct via `construct` classmethod.
Attributes:
nx: Number of cells.
dx: Distance between cell centers.
face_centers: Coordinates of face centers.
cell_centers: Coordinates of cell centers.
"""

nx: pydantic.PositiveInt
dx: pydantic.PositiveFloat
face_centers: torax_pydantic.NumpyArray1D
cell_centers: torax_pydantic.NumpyArray1D

def __eq__(self, other: Grid1D) -> bool:
return (
self.nx == other.nx
and self.dx == other.dx
and np.array_equal(self.face_centers, other.face_centers)
and np.array_equal(self.cell_centers, other.cell_centers)
)

def __hash__(self) -> int:
return hash((self.nx, self.dx))

@classmethod
def construct(cls, nx: int, dx: float) -> Grid1D:
"""Constructs a Grid1D.
Args:
nx: Number of cells.
dx: Distance between cell centers.
Returns:
grid: A Grid1D with the remaining fields filled in.
"""

# Note: nx needs to be an int so that the shape `nx + 1` is not a Jax
# tracer.

return Grid1D(
nx=nx,
dx=dx,
face_centers=np.linspace(0, nx * dx, nx + 1),
cell_centers=np.linspace(dx * 0.5, (nx - 0.5) * dx, nx),
)


def face_to_cell(face: chex.Array) -> chex.Array:
"""Infers cell values corresponding to a vector of face values.
Expand Down Expand Up @@ -125,8 +73,7 @@ class Geometry:
Attributes:
geometry_type: Type of geometry model used. See `GeometryType` for options.
torax_mesh: `Grid1D` object representing the radial mesh used by TORAX.
Phi: Toroidal magnetic flux at each radial grid point
[:math:`\mathrm{Wb}`].
Phi: Toroidal magnetic flux at each radial grid point [:math:`\mathrm{Wb}`].
Phi_face: Toroidal magnetic flux at each radial face [:math:`\mathrm{Wb}`].
Rmaj: Tokamak major radius (geometric center) [:math:`\mathrm{m}`].
Rmin: Tokamak minor radius [:math:`\mathrm{m}`].
Expand All @@ -150,8 +97,8 @@ class Geometry:
[:math:`\mathrm{m}^2`]. Equal to vpr / (:math:`2 \pi` Rmaj).
spr_face: Derivative of plasma surface area enclosed by each flux surface,
with respect to the normalized toroidal flux coordinate rho_face_norm on
face grid [:math:`\mathrm{m}^2`]. Equal to
vpr_face / (:math:`2 \pi` Rmaj).
face grid [:math:`\mathrm{m}^2`]. Equal to vpr_face / (:math:`2 \pi`
Rmaj).
spr_hires: Derivative of plasma surface area enclosed by each flux surface
on a higher resolution grid, with respect to the normalized toroidal flux
coordinate rho_norm. [:math:`\mathrm{m}^2`].
Expand All @@ -161,17 +108,17 @@ class Geometry:
grid [dimensionless].
g0: Flux surface averaged radial derivative of the plasma volume:
:math:`\langle \nabla V \rangle` on cell grid [:math:`\mathrm{m}^2`].
g0_face: Flux surface averaged :math:`\langle \nabla V \rangle` on the
faces [:math:`\mathrm{m}^2`].
g0_face: Flux surface averaged :math:`\langle \nabla V \rangle` on the faces
[:math:`\mathrm{m}^2`].
g1: Flux surface averaged :math:`\langle (\nabla V)^2 \rangle` on cell grid
[:math:`\mathrm{m}^4`].
g1_face: Flux surface averaged :math:`\langle (\nabla V)^2 \rangle` on the
faces [:math:`\mathrm{m}^4`].
g2: Flux surface averaged :math:`\langle (\nabla V)^2 / R^2 \rangle` on
cell grid [:math:`\mathrm{m}^2`], where R is the major radius along the
flux surface being averaged.
g2_face: Flux surface averaged :math:`\langle (\nabla V)^2 / R^2 \rangle`
on the faces [:math:`\mathrm{m}^2`].
g2: Flux surface averaged :math:`\langle (\nabla V)^2 / R^2 \rangle` on cell
grid [:math:`\mathrm{m}^2`], where R is the major radius along the flux
surface being averaged.
g2_face: Flux surface averaged :math:`\langle (\nabla V)^2 / R^2 \rangle` on
the faces [:math:`\mathrm{m}^2`].
g3: Flux surface averaged :math:`\langle 1 / R^2 \rangle` on cell grid
[:math:`\mathrm{m}^{-2}`].
g3_face: Flux surface averaged :math:`\langle 1 / R^2 \rangle` on the faces
Expand All @@ -195,32 +142,32 @@ class Geometry:
Rin_face: Radius of the flux surface at the inboard side at midplane
[:math:`\mathrm{m}`] on face grid.
Rout: Radius of the flux surface at the outboard side at midplane
[:math:`\mathrm{m}`] on cell grid. Outboard side is defined as the
maximum radial extent of the flux surface.
[:math:`\mathrm{m}`] on cell grid. Outboard side is defined as the maximum
radial extent of the flux surface.
Rout_face: Radius of the flux surface at the outboard side at midplane
[:math:`\mathrm{m}`] on face grid.
delta_face: Average of upper and lower triangularity of each flux surface
at the faces [dimensionless]. Upper triangularity is defined as
(Rmaj_local - R_upper) / Rmin_local, where Rmaj_local = (Rout+Rin)/2,
Rmin_local = (Rout-Rin)/2, and R_upper is the radial location of the
upper extent of the flux surface. Lower triangularity is defined as
(Rmaj_local - R_lower) / Rmin_local, where R_lower is the radial
location of the lower extent of the flux surface.
delta_face: Average of upper and lower triangularity of each flux surface at
the faces [dimensionless]. Upper triangularity is defined as (Rmaj_local -
R_upper) / Rmin_local, where Rmaj_local = (Rout+Rin)/2, Rmin_local =
(Rout-Rin)/2, and R_upper is the radial location of the upper extent of
the flux surface. Lower triangularity is defined as (Rmaj_local - R_lower)
/ Rmin_local, where R_lower is the radial location of the lower extent of
the flux surface.
elongation: Plasma elongation profile on cell grid [dimensionless].
Elongation is defined as (Z_upper - Z_lower) / (2.0 * Rmin_local),
where Z_upper and Z_lower are the Z coordinates of the upper and lower
extent of the flux surface.
Elongation is defined as (Z_upper - Z_lower) / (2.0 * Rmin_local), where
Z_upper and Z_lower are the Z coordinates of the upper and lower extent of
the flux surface.
elongation_face: Plasma elongation profile on face grid [dimensionless].
Phibdot: Time derivative of the toroidal magnetic flux
[:math:`\mathrm{Wb/s}`]. Calculated across a time interval using ``Phi``
from the Geometry objects at time t and t + dt.
See ``torax.orchestration.step_function`` for more details.
from the Geometry objects at time t and t + dt. See
``torax.orchestration.step_function`` for more details.
_z_magnetic_axis: Vertical position of the magnetic axis
[:math:`\mathrm{m}`].
"""

geometry_type: GeometryType
torax_mesh: Grid1D
torax_mesh: torax_pydantic.Grid1D
Phi: chex.Array
Phi_face: chex.Array
Rmaj: chex.Array
Expand Down
9 changes: 4 additions & 5 deletions torax/geometry/geometry_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torax import interpolated_param
from torax import jax_utils
from torax.geometry import geometry

from torax.torax_pydantic import torax_pydantic

# Using invalid-name because we are using the same naming convention as the
# external physics implementations
Expand Down Expand Up @@ -86,7 +86,7 @@ def __call__(
"""

@property
def torax_mesh(self) -> geometry.Grid1D:
def torax_mesh(self) -> torax_pydantic.Grid1D:
"""Returns the mesh used by Torax, this is consistent across time."""


Expand All @@ -103,7 +103,7 @@ def __call__(self, t: chex.Numeric) -> geometry.Geometry:
return self._geo

@property
def torax_mesh(self) -> geometry.Grid1D:
def torax_mesh(self) -> torax_pydantic.Grid1D:
return self._geo.torax_mesh


Expand All @@ -112,7 +112,7 @@ class TimeDependentGeometryProvider:
"""A geometry provider which holds values to interpolate based on time."""

geometry_type: geometry.GeometryType
torax_mesh: geometry.Grid1D
torax_mesh: torax_pydantic.Grid1D
drho_norm: interpolated_param.InterpolatedVarSingleAxis
Phi: interpolated_param.InterpolatedVarSingleAxis
Phi_face: interpolated_param.InterpolatedVarSingleAxis
Expand Down Expand Up @@ -224,4 +224,3 @@ def _get_geometry_base(
def __call__(self, t: chex.Numeric) -> geometry.Geometry:
"""Returns a Geometry instance at the given time."""
return self._get_geometry_base(t, geometry.Geometry)

3 changes: 2 additions & 1 deletion torax/geometry/standard_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from torax.geometry import geometry
from torax.geometry import geometry_loader
from torax.geometry import geometry_provider
from torax.torax_pydantic import torax_pydantic

# pylint: disable=invalid-name

Expand Down Expand Up @@ -978,7 +979,7 @@ def build_standard_geometry(
# fill geometry structure
drho_norm = float(rho_norm_intermediate[-1]) / intermediate.n_rho
# normalized grid
mesh = geometry.Grid1D.construct(nx=intermediate.n_rho, dx=drho_norm)
mesh = torax_pydantic.Grid1D.construct(nx=intermediate.n_rho, dx=drho_norm)
rho_b = rho_intermediate[-1] # radius denormalization constant
# helper variables for mesh cells and faces
rho_face_norm = mesh.face_centers
Expand Down
Loading

0 comments on commit 0b1bbd2

Please sign in to comment.