From 097a493140bfd9450edcb0b5500ec67088bd76c7 Mon Sep 17 00:00:00 2001 From: Anushan Fernando Date: Mon, 29 Jul 2024 03:38:40 -0700 Subject: [PATCH] Make geo a required parameter for build_dynamic_runtime_params_slice. It is implicitly required by the structure of `GeneralRuntimeParams` now and less misleading to remove the invalid `None` from the type hint. PiperOrigin-RevId: 657139590 --- torax/config/runtime_params_slice.py | 4 ++-- torax/tests/test_lib/sim_test_case.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/torax/config/runtime_params_slice.py b/torax/config/runtime_params_slice.py index 3c729267..3ecc8723 100644 --- a/torax/config/runtime_params_slice.py +++ b/torax/config/runtime_params_slice.py @@ -261,11 +261,11 @@ class StaticRuntimeParamsSlice: def build_dynamic_runtime_params_slice( runtime_params: general_runtime_params.GeneralRuntimeParams, + geo: geometry.Geometry, transport: transport_model_params.RuntimeParams | None = None, sources: dict[str, sources_params.RuntimeParams] | None = None, stepper: stepper_params.RuntimeParams | None = None, t: chex.Numeric | None = None, - geo: geometry.Geometry | None = None, ) -> DynamicRuntimeParamsSlice: """Builds a DynamicRuntimeParamsSlice.""" transport = transport or transport_model_params.RuntimeParams() @@ -377,7 +377,7 @@ def __init__( def __call__( self, t: chex.Numeric, - geo: geometry.Geometry | None = None, + geo: geometry.Geometry, ) -> DynamicRuntimeParamsSlice: """Returns a DynamicRuntimeParamsSlice to use during time t of the sim.""" return build_dynamic_runtime_params_slice( diff --git a/torax/tests/test_lib/sim_test_case.py b/torax/tests/test_lib/sim_test_case.py index e0d91c3d..f4a07905 100644 --- a/torax/tests/test_lib/sim_test_case.py +++ b/torax/tests/test_lib/sim_test_case.py @@ -24,6 +24,7 @@ import jax.numpy as jnp import numpy as np import torax +from torax import geometry from torax import sim as sim_lib from torax import simulation_app from torax import state as state_lib @@ -309,6 +310,7 @@ def make_frozen_optimizer_stepper( source_models: source_models_lib.SourceModels, runtime_params: general_runtime_params.GeneralRuntimeParams, transport_params: transport_params_lib.RuntimeParams, + geo: geometry.Geometry, ) -> stepper_lib.Stepper: """Makes an optimizer stepper with frozen coefficients. @@ -322,6 +324,7 @@ def make_frozen_optimizer_stepper( state evolution equations. runtime_params: General TORAX runtime input parameters. transport_params: Runtime params for the transport model. + geo: The geometry of the simulation. Returns: Stepper: the stepper. @@ -332,6 +335,7 @@ def make_frozen_optimizer_stepper( runtime_params=runtime_params, transport=transport_params, sources=source_models_builder.runtime_params, + geo=geo, ) ) callback_builder = functools.partial( @@ -349,6 +353,7 @@ def make_frozen_newton_raphson_stepper( transport_model: transport_model_lib.TransportModel, source_models: source_models_lib.SourceModels, runtime_params: general_runtime_params.GeneralRuntimeParams, + geo: geometry.Geometry, ) -> stepper_lib.Stepper: """Makes a Newton Raphson stepper with frozen coefficients. @@ -361,13 +366,17 @@ def make_frozen_newton_raphson_stepper( source_models: TORAX sources/sinks used to compute profile terms in the state evolution equations. runtime_params: General TORAX runtime input parameters. + geo: The geometry of the simulation. Returns: Stepper: the stepper. """ # Get the dynamic runtime params for the start of the simulation. dynamic_runtime_params_slice = ( - runtime_params_slice.build_dynamic_runtime_params_slice(runtime_params) + runtime_params_slice.build_dynamic_runtime_params_slice( + runtime_params, + geo=geo, + ) ) callback_builder = functools.partial( sim_lib.FrozenCoeffsCallback,