Skip to content

Commit

Permalink
Make geo a required parameter for build_dynamic_runtime_params_slice.
Browse files Browse the repository at this point in the history
It is implicitly required by the structure of `GeneralRuntimeParams` now and less misleading to remove the invalid `None` from the type hint.

PiperOrigin-RevId: 657139590
  • Loading branch information
Nush395 authored and Torax team committed Jul 29, 2024
1 parent 6577dce commit 097a493
Showing 2 changed files with 12 additions and 3 deletions.
4 changes: 2 additions & 2 deletions torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
@@ -261,11 +261,11 @@ class StaticRuntimeParamsSlice:

def build_dynamic_runtime_params_slice(
runtime_params: general_runtime_params.GeneralRuntimeParams,
geo: geometry.Geometry,
transport: transport_model_params.RuntimeParams | None = None,
sources: dict[str, sources_params.RuntimeParams] | None = None,
stepper: stepper_params.RuntimeParams | None = None,
t: chex.Numeric | None = None,
geo: geometry.Geometry | None = None,
) -> DynamicRuntimeParamsSlice:
"""Builds a DynamicRuntimeParamsSlice."""
transport = transport or transport_model_params.RuntimeParams()
@@ -377,7 +377,7 @@ def __init__(
def __call__(
self,
t: chex.Numeric,
geo: geometry.Geometry | None = None,
geo: geometry.Geometry,
) -> DynamicRuntimeParamsSlice:
"""Returns a DynamicRuntimeParamsSlice to use during time t of the sim."""
return build_dynamic_runtime_params_slice(
11 changes: 10 additions & 1 deletion torax/tests/test_lib/sim_test_case.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
import jax.numpy as jnp
import numpy as np
import torax
from torax import geometry
from torax import sim as sim_lib
from torax import simulation_app
from torax import state as state_lib
@@ -309,6 +310,7 @@ def make_frozen_optimizer_stepper(
source_models: source_models_lib.SourceModels,
runtime_params: general_runtime_params.GeneralRuntimeParams,
transport_params: transport_params_lib.RuntimeParams,
geo: geometry.Geometry,
) -> stepper_lib.Stepper:
"""Makes an optimizer stepper with frozen coefficients.
@@ -322,6 +324,7 @@ def make_frozen_optimizer_stepper(
state evolution equations.
runtime_params: General TORAX runtime input parameters.
transport_params: Runtime params for the transport model.
geo: The geometry of the simulation.
Returns:
Stepper: the stepper.
@@ -332,6 +335,7 @@ def make_frozen_optimizer_stepper(
runtime_params=runtime_params,
transport=transport_params,
sources=source_models_builder.runtime_params,
geo=geo,
)
)
callback_builder = functools.partial(
@@ -349,6 +353,7 @@ def make_frozen_newton_raphson_stepper(
transport_model: transport_model_lib.TransportModel,
source_models: source_models_lib.SourceModels,
runtime_params: general_runtime_params.GeneralRuntimeParams,
geo: geometry.Geometry,
) -> stepper_lib.Stepper:
"""Makes a Newton Raphson stepper with frozen coefficients.
@@ -361,13 +366,17 @@ def make_frozen_newton_raphson_stepper(
source_models: TORAX sources/sinks used to compute profile terms in the
state evolution equations.
runtime_params: General TORAX runtime input parameters.
geo: The geometry of the simulation.
Returns:
Stepper: the stepper.
"""
# Get the dynamic runtime params for the start of the simulation.
dynamic_runtime_params_slice = (
runtime_params_slice.build_dynamic_runtime_params_slice(runtime_params)
runtime_params_slice.build_dynamic_runtime_params_slice(
runtime_params,
geo=geo,
)
)
callback_builder = functools.partial(
sim_lib.FrozenCoeffsCallback,

0 comments on commit 097a493

Please sign in to comment.