From 10856549e85b500e21da5f068f7d0b235d59159c Mon Sep 17 00:00:00 2001 From: Sebastian Bodenstein Date: Mon, 29 Jul 2024 08:57:08 -0700 Subject: [PATCH] Typing improvements. PiperOrigin-RevId: 657216962 --- torax/config/runtime_params.py | 5 +++-- torax/constants.py | 4 +++- torax/fvm/block_1d_coeffs.py | 8 +++++--- torax/fvm/discrete_system.py | 9 ++++++--- torax/fvm/newton_raphson_solve_block.py | 4 ++-- torax/fvm/optimizer_solve_block.py | 6 ++++-- torax/fvm/residual_and_loss.py | 10 +++++++--- torax/simulation_app.py | 4 ++-- torax/sources/formula_config.py | 3 ++- torax/sources/source.py | 6 +++--- torax/stepper/linear_theta_method.py | 4 ++-- torax/stepper/runtime_params.py | 4 ++-- .../time_step_calculator/array_time_step_calculator.py | 4 ++-- torax/time_step_calculator/chi_time_step_calculator.py | 6 +++--- torax/transport_model/base_qlknn_model.py | 4 +++- torax/transport_model/qlknn_10d.py | 4 ++-- torax/transport_model/qlknn_wrapper.py | 8 +++++--- torax/transport_model/runtime_params.py | 4 +++- 18 files changed, 59 insertions(+), 38 deletions(-) diff --git a/torax/config/runtime_params.py b/torax/config/runtime_params.py index 77085367..835c5292 100644 --- a/torax/config/runtime_params.py +++ b/torax/config/runtime_params.py @@ -18,6 +18,7 @@ from collections.abc import Mapping import dataclasses +from typing import TypeAlias import chex from torax import interpolated_param @@ -25,9 +26,9 @@ # Type-alias for clarity. While the InterpolatedVarSingleAxis can vary across # any field, in here, we mainly use it to handle time-dependent parameters. -TimeInterpolated = interpolated_param.TimeInterpolated +TimeInterpolated: TypeAlias = interpolated_param.TimeInterpolated # Type-alias for clarity for time-and-rho-dependent parameters. -TimeRhoInterpolated = ( +TimeRhoInterpolated: TypeAlias = ( interpolated_param.TimeRhoInterpolated ) diff --git a/torax/constants.py b/torax/constants.py index 63787f8b..0768f326 100644 --- a/torax/constants.py +++ b/torax/constants.py @@ -17,6 +17,8 @@ This module saves immutable constants used in various calculations. """ +from typing import Final + import chex from jax import numpy as jnp @@ -32,7 +34,7 @@ class Constants: eps: chex.Numeric -CONSTANTS = Constants( +CONSTANTS: Final[Constants] = Constants( keV2J=1e3 * 1.6e-19, mp=1.67e-27, qe=1.6e-19, diff --git a/torax/fvm/block_1d_coeffs.py b/torax/fvm/block_1d_coeffs.py index b83e2929..26716444 100644 --- a/torax/fvm/block_1d_coeffs.py +++ b/torax/fvm/block_1d_coeffs.py @@ -20,7 +20,7 @@ calculations specific to plasma physics to provide these coefficients. """ -from typing import Any, Optional, Protocol +from typing import Any, Optional, Protocol, TypeAlias import chex import jax @@ -37,11 +37,13 @@ # ((a, b), (c, d)) where a, b, c, d are each jax.Array # # ((a, None), (None, d)) : represents a diagonal block matrix -OptionalTupleMatrix = Optional[tuple[tuple[Optional[jax.Array], ...], ...]] +OptionalTupleMatrix: TypeAlias = Optional[ + tuple[tuple[Optional[jax.Array], ...], ...] +] # Alias for better readability. -AuxiliaryOutput = Any +AuxiliaryOutput: TypeAlias = Any @chex.dataclass(frozen=True) diff --git a/torax/fvm/discrete_system.py b/torax/fvm/discrete_system.py index 05980c09..99d2977b 100644 --- a/torax/fvm/discrete_system.py +++ b/torax/fvm/discrete_system.py @@ -26,6 +26,8 @@ from __future__ import annotations +from typing import TypeAlias + import jax from jax import numpy as jnp from torax.fvm import block_1d_coeffs @@ -33,9 +35,10 @@ from torax.fvm import convection_terms from torax.fvm import diffusion_terms -AuxiliaryOutput = block_1d_coeffs.AuxiliaryOutput -Block1DCoeffs = block_1d_coeffs.Block1DCoeffs -Block1DCoeffsCallback = block_1d_coeffs.Block1DCoeffsCallback + +AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput +Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs +Block1DCoeffsCallback: TypeAlias = block_1d_coeffs.Block1DCoeffsCallback def calc_c( diff --git a/torax/fvm/newton_raphson_solve_block.py b/torax/fvm/newton_raphson_solve_block.py index ee9f78ed..d60151ec 100644 --- a/torax/fvm/newton_raphson_solve_block.py +++ b/torax/fvm/newton_raphson_solve_block.py @@ -18,7 +18,7 @@ """ import functools -from typing import Callable +from typing import Callable, Final from absl import logging import jax @@ -45,7 +45,7 @@ # Delta is a vector. If no entry of delta is above this magnitude, we terminate # the delta loop. This is to avoid getting stuck in an infinite loop in edge # cases with bad numerics. -MIN_DELTA = 1e-7 +MIN_DELTA: Final[float] = 1e-7 def _log_iterations( diff --git a/torax/fvm/optimizer_solve_block.py b/torax/fvm/optimizer_solve_block.py index aec5e0c2..5bbd2627 100644 --- a/torax/fvm/optimizer_solve_block.py +++ b/torax/fvm/optimizer_solve_block.py @@ -16,6 +16,8 @@ See function docstring for details. """ +from typing import TypeAlias + import jax from torax import geometry from torax import state @@ -31,8 +33,8 @@ from torax.transport_model import transport_model as transport_model_lib -AuxiliaryOutput = block_1d_coeffs.AuxiliaryOutput -Block1DCoeffsCallback = block_1d_coeffs.Block1DCoeffsCallback +AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput +Block1DCoeffsCallback: TypeAlias = block_1d_coeffs.Block1DCoeffsCallback def optimizer_solve_block( diff --git a/torax/fvm/residual_and_loss.py b/torax/fvm/residual_and_loss.py index 66649dc9..29102952 100644 --- a/torax/fvm/residual_and_loss.py +++ b/torax/fvm/residual_and_loss.py @@ -19,7 +19,10 @@ Residual functions are for use with e.g. the Newton-Raphson method while loss functions can be minimized using any optimization method. """ + import functools +from typing import TypeAlias + import chex import jax from jax import numpy as jnp @@ -38,9 +41,10 @@ from torax.sources import source_profiles from torax.transport_model import transport_model as transport_model_lib -AuxiliaryOutput = block_1d_coeffs.AuxiliaryOutput -Block1DCoeffs = block_1d_coeffs.Block1DCoeffs -Block1DCoeffsCallback = block_1d_coeffs.Block1DCoeffsCallback + +AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput +Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs +Block1DCoeffsCallback: TypeAlias = block_1d_coeffs.Block1DCoeffsCallback @functools.partial( diff --git a/torax/simulation_app.py b/torax/simulation_app.py index e7b0a0fe..7150b566 100644 --- a/torax/simulation_app.py +++ b/torax/simulation_app.py @@ -56,7 +56,7 @@ def run(_): import os import sys -from typing import Any, Callable +from typing import Any, Callable, Final from absl import logging import jax @@ -75,7 +75,7 @@ def run(_): # String printed before printing the output file path -WRITE_PREFIX = 'Wrote simulation output to ' +WRITE_PREFIX: Final[str] = 'Wrote simulation output to ' # For logging. diff --git a/torax/sources/formula_config.py b/torax/sources/formula_config.py index 311713b6..b4ec65db 100644 --- a/torax/sources/formula_config.py +++ b/torax/sources/formula_config.py @@ -17,13 +17,14 @@ from __future__ import annotations import dataclasses +from typing import TypeAlias import chex from torax import interpolated_param from torax.config import config_args -TimeInterpolated = interpolated_param.TimeInterpolated +TimeInterpolated: TypeAlias = interpolated_param.TimeInterpolated @dataclasses.dataclass diff --git a/torax/sources/source.py b/torax/sources/source.py index db0aa5af..8cc4b227 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -27,7 +27,7 @@ import enum import types import typing -from typing import Any, Callable, Optional, Protocol +from typing import Any, Callable, Optional, Protocol, TypeAlias # We use Optional here because | doesn't work with string name types. # We use string name 'source_models.SourceModels' in this file to avoid @@ -45,7 +45,7 @@ # Sources implement these functions to be able to provide source profiles. # pytype bug: 'source_models.SourceModels' not treated as forward reference -SourceProfileFunction = Callable[ # pytype: disable=name-error +SourceProfileFunction: TypeAlias = Callable[ # pytype: disable=name-error [ # Arguments runtime_params_slice.DynamicRuntimeParamsSlice, # General config params runtime_params_lib.DynamicRuntimeParams, # Source-specific params. @@ -61,7 +61,7 @@ # Any callable which takes the dynamic runtime_params, geometry, and optional # core profiles, and outputs a shape corresponding to the expected output of a # source. See how these types of functions are used in the Source class below. -SourceOutputShapeFunction = Callable[ +SourceOutputShapeFunction: TypeAlias = Callable[ [ # Arguments geometry.Geometry, ], diff --git a/torax/stepper/linear_theta_method.py b/torax/stepper/linear_theta_method.py index baa38810..a5005532 100644 --- a/torax/stepper/linear_theta_method.py +++ b/torax/stepper/linear_theta_method.py @@ -16,7 +16,7 @@ from collections.abc import Callable import dataclasses -from typing import Type +from typing import Type, TypeAlias import jax from torax import geometry from torax import sim @@ -134,7 +134,7 @@ def _default_linear_builder( # Type-alias so that users only need to import this file. -LinearRuntimeParams = runtime_params_lib.RuntimeParams +LinearRuntimeParams: TypeAlias = runtime_params_lib.RuntimeParams @dataclasses.dataclass(kw_only=True) diff --git a/torax/stepper/runtime_params.py b/torax/stepper/runtime_params.py index e6e1f40e..5723a20e 100644 --- a/torax/stepper/runtime_params.py +++ b/torax/stepper/runtime_params.py @@ -17,14 +17,14 @@ from __future__ import annotations import dataclasses -from typing import Any, Iterable +from typing import Any, Iterable, TypeAlias import chex from torax import interpolated_param from torax.config import config_args -TimeInterpolated = interpolated_param.TimeInterpolated +TimeInterpolated: TypeAlias = interpolated_param.TimeInterpolated @dataclasses.dataclass(kw_only=True) diff --git a/torax/time_step_calculator/array_time_step_calculator.py b/torax/time_step_calculator/array_time_step_calculator.py index c8a911c9..630da2e8 100644 --- a/torax/time_step_calculator/array_time_step_calculator.py +++ b/torax/time_step_calculator/array_time_step_calculator.py @@ -17,7 +17,7 @@ A TimeStepCalculator that iterates over entries in an array. """ -from typing import Union +from typing import TypeAlias, Union import chex import jax @@ -27,7 +27,7 @@ from torax.config import runtime_params_slice from torax.time_step_calculator import time_step_calculator -State = int +State: TypeAlias = int # TODO(b/337844885). Remove the array option and make fixed_dt time-dependent diff --git a/torax/time_step_calculator/chi_time_step_calculator.py b/torax/time_step_calculator/chi_time_step_calculator.py index a2df6224..ec9540df 100644 --- a/torax/time_step_calculator/chi_time_step_calculator.py +++ b/torax/time_step_calculator/chi_time_step_calculator.py @@ -18,7 +18,7 @@ """ import functools -from typing import Union +from typing import Final, TypeAlias, Union import jax from jax import numpy as jnp @@ -29,8 +29,8 @@ from torax.time_step_calculator import time_step_calculator # Dummy state and type for compatibility with time_step_calculator base class -STATE = None -State = type(STATE) +STATE: Final[None] = None +State: TypeAlias = type(STATE) class ChiTimeStepCalculator(time_step_calculator.TimeStepCalculator[State]): diff --git a/torax/transport_model/base_qlknn_model.py b/torax/transport_model/base_qlknn_model.py index fe4c9dfd..7bebbb63 100644 --- a/torax/transport_model/base_qlknn_model.py +++ b/torax/transport_model/base_qlknn_model.py @@ -14,10 +14,12 @@ """Base class for QLKNN Models.""" import abc +from typing import TypeAlias + import jax -ModelOutput = dict[str, jax.Array] +ModelOutput: TypeAlias = dict[str, jax.Array] class BaseQLKNNModel(abc.ABC): diff --git a/torax/transport_model/qlknn_10d.py b/torax/transport_model/qlknn_10d.py index 0c0b1abb..4be2800d 100644 --- a/torax/transport_model/qlknn_10d.py +++ b/torax/transport_model/qlknn_10d.py @@ -18,7 +18,7 @@ from collections.abc import Mapping import json import os -from typing import Any, Callable +from typing import Any, Callable, Final import flax.linen as nn import immutabledict @@ -29,7 +29,7 @@ # Internal import. # Move this to common lib. -_ACTIVATION_FNS: Mapping[str, Callable[[jax.Array], jax.Array]] = ( +_ACTIVATION_FNS: Final[Mapping[str, Callable[[jax.Array], jax.Array]]] = ( immutabledict.immutabledict({ 'relu': nn.relu, 'tanh': nn.tanh, diff --git a/torax/transport_model/qlknn_wrapper.py b/torax/transport_model/qlknn_wrapper.py index e9ba33a0..543d22ee 100644 --- a/torax/transport_model/qlknn_wrapper.py +++ b/torax/transport_model/qlknn_wrapper.py @@ -26,7 +26,7 @@ import functools import logging import os -from typing import Callable +from typing import Callable, Final import chex import jax @@ -45,7 +45,7 @@ # Environment variable for the QLKNN model. Used if the model path # is not set in the config. -MODEL_PATH_ENV_VAR = 'TORAX_QLKNN_MODEL_PATH' +MODEL_PATH_ENV_VAR: Final[str] = 'TORAX_QLKNN_MODEL_PATH' # If no path is set in either the config or the environment variable, use # this path. DEFAULT_MODEL_PATH = '~/qlknn_hyper' @@ -108,7 +108,9 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): q_sawtooth_proxy: bool -_EPSILON_NN: float = 1 / 3 # fixed inverse aspect ratio used to train QLKNN10D +_EPSILON_NN: Final[float] = ( + 1 / 3 +) # fixed inverse aspect ratio used to train QLKNN10D # Memoize, but evict the old model if a new path is given. diff --git a/torax/transport_model/runtime_params.py b/torax/transport_model/runtime_params.py index c752b97a..43b50106 100644 --- a/torax/transport_model/runtime_params.py +++ b/torax/transport_model/runtime_params.py @@ -20,6 +20,8 @@ from __future__ import annotations +from typing import TypeAlias + import chex from torax import interpolated_param from torax import jax_utils @@ -29,7 +31,7 @@ # Type-alias for clarity. While the InterpolatedVarSingleAxiss can vary across # any field, in these classes, we mainly use it to handle time-dependent # parameters. -TimeInterpolated = interpolated_param.TimeInterpolated +TimeInterpolated: TypeAlias = interpolated_param.TimeInterpolated # pylint: disable=invalid-name