Skip to content

Commit

Permalink
Typing improvements.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657216962
  • Loading branch information
sbodenstein authored and Torax team committed Jul 29, 2024
1 parent 204933d commit 1085654
Show file tree
Hide file tree
Showing 18 changed files with 59 additions and 38 deletions.
5 changes: 3 additions & 2 deletions torax/config/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@

from collections.abc import Mapping
import dataclasses
from typing import TypeAlias

import chex
from torax import interpolated_param


# Type-alias for clarity. While the InterpolatedVarSingleAxis can vary across
# any field, in here, we mainly use it to handle time-dependent parameters.
TimeInterpolated = interpolated_param.TimeInterpolated
TimeInterpolated: TypeAlias = interpolated_param.TimeInterpolated
# Type-alias for clarity for time-and-rho-dependent parameters.
TimeRhoInterpolated = (
TimeRhoInterpolated: TypeAlias = (
interpolated_param.TimeRhoInterpolated
)

Expand Down
4 changes: 3 additions & 1 deletion torax/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
This module saves immutable constants used in various calculations.
"""

from typing import Final

import chex
from jax import numpy as jnp

Expand All @@ -32,7 +34,7 @@ class Constants:
eps: chex.Numeric


CONSTANTS = Constants(
CONSTANTS: Final[Constants] = Constants(
keV2J=1e3 * 1.6e-19,
mp=1.67e-27,
qe=1.6e-19,
Expand Down
8 changes: 5 additions & 3 deletions torax/fvm/block_1d_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
calculations specific to plasma physics to provide these coefficients.
"""

from typing import Any, Optional, Protocol
from typing import Any, Optional, Protocol, TypeAlias

import chex
import jax
Expand All @@ -37,11 +37,13 @@
# ((a, b), (c, d)) where a, b, c, d are each jax.Array
#
# ((a, None), (None, d)) : represents a diagonal block matrix
OptionalTupleMatrix = Optional[tuple[tuple[Optional[jax.Array], ...], ...]]
OptionalTupleMatrix: TypeAlias = Optional[
tuple[tuple[Optional[jax.Array], ...], ...]
]


# Alias for better readability.
AuxiliaryOutput = Any
AuxiliaryOutput: TypeAlias = Any


@chex.dataclass(frozen=True)
Expand Down
9 changes: 6 additions & 3 deletions torax/fvm/discrete_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@

from __future__ import annotations

from typing import TypeAlias

import jax
from jax import numpy as jnp
from torax.fvm import block_1d_coeffs
from torax.fvm import cell_variable
from torax.fvm import convection_terms
from torax.fvm import diffusion_terms

AuxiliaryOutput = block_1d_coeffs.AuxiliaryOutput
Block1DCoeffs = block_1d_coeffs.Block1DCoeffs
Block1DCoeffsCallback = block_1d_coeffs.Block1DCoeffsCallback

AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput
Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs
Block1DCoeffsCallback: TypeAlias = block_1d_coeffs.Block1DCoeffsCallback


def calc_c(
Expand Down
4 changes: 2 additions & 2 deletions torax/fvm/newton_raphson_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

import functools
from typing import Callable
from typing import Callable, Final

from absl import logging
import jax
Expand All @@ -45,7 +45,7 @@
# Delta is a vector. If no entry of delta is above this magnitude, we terminate
# the delta loop. This is to avoid getting stuck in an infinite loop in edge
# cases with bad numerics.
MIN_DELTA = 1e-7
MIN_DELTA: Final[float] = 1e-7


def _log_iterations(
Expand Down
6 changes: 4 additions & 2 deletions torax/fvm/optimizer_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
See function docstring for details.
"""

from typing import TypeAlias

import jax
from torax import geometry
from torax import state
Expand All @@ -31,8 +33,8 @@
from torax.transport_model import transport_model as transport_model_lib


AuxiliaryOutput = block_1d_coeffs.AuxiliaryOutput
Block1DCoeffsCallback = block_1d_coeffs.Block1DCoeffsCallback
AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput
Block1DCoeffsCallback: TypeAlias = block_1d_coeffs.Block1DCoeffsCallback


def optimizer_solve_block(
Expand Down
10 changes: 7 additions & 3 deletions torax/fvm/residual_and_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
Residual functions are for use with e.g. the Newton-Raphson method
while loss functions can be minimized using any optimization method.
"""

import functools
from typing import TypeAlias

import chex
import jax
from jax import numpy as jnp
Expand All @@ -38,9 +41,10 @@
from torax.sources import source_profiles
from torax.transport_model import transport_model as transport_model_lib

AuxiliaryOutput = block_1d_coeffs.AuxiliaryOutput
Block1DCoeffs = block_1d_coeffs.Block1DCoeffs
Block1DCoeffsCallback = block_1d_coeffs.Block1DCoeffsCallback

AuxiliaryOutput: TypeAlias = block_1d_coeffs.AuxiliaryOutput
Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs
Block1DCoeffsCallback: TypeAlias = block_1d_coeffs.Block1DCoeffsCallback


@functools.partial(
Expand Down
4 changes: 2 additions & 2 deletions torax/simulation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run(_):
import os

import sys
from typing import Any, Callable
from typing import Any, Callable, Final

from absl import logging
import jax
Expand All @@ -75,7 +75,7 @@ def run(_):


# String printed before printing the output file path
WRITE_PREFIX = 'Wrote simulation output to '
WRITE_PREFIX: Final[str] = 'Wrote simulation output to '


# For logging.
Expand Down
3 changes: 2 additions & 1 deletion torax/sources/formula_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
from __future__ import annotations

import dataclasses
from typing import TypeAlias

import chex
from torax import interpolated_param
from torax.config import config_args


TimeInterpolated = interpolated_param.TimeInterpolated
TimeInterpolated: TypeAlias = interpolated_param.TimeInterpolated


@dataclasses.dataclass
Expand Down
6 changes: 3 additions & 3 deletions torax/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import enum
import types
import typing
from typing import Any, Callable, Optional, Protocol
from typing import Any, Callable, Optional, Protocol, TypeAlias

# We use Optional here because | doesn't work with string name types.
# We use string name 'source_models.SourceModels' in this file to avoid
Expand All @@ -45,7 +45,7 @@

# Sources implement these functions to be able to provide source profiles.
# pytype bug: 'source_models.SourceModels' not treated as forward reference
SourceProfileFunction = Callable[ # pytype: disable=name-error
SourceProfileFunction: TypeAlias = Callable[ # pytype: disable=name-error
[ # Arguments
runtime_params_slice.DynamicRuntimeParamsSlice, # General config params
runtime_params_lib.DynamicRuntimeParams, # Source-specific params.
Expand All @@ -61,7 +61,7 @@
# Any callable which takes the dynamic runtime_params, geometry, and optional
# core profiles, and outputs a shape corresponding to the expected output of a
# source. See how these types of functions are used in the Source class below.
SourceOutputShapeFunction = Callable[
SourceOutputShapeFunction: TypeAlias = Callable[
[ # Arguments
geometry.Geometry,
],
Expand Down
4 changes: 2 additions & 2 deletions torax/stepper/linear_theta_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from collections.abc import Callable
import dataclasses
from typing import Type
from typing import Type, TypeAlias
import jax
from torax import geometry
from torax import sim
Expand Down Expand Up @@ -134,7 +134,7 @@ def _default_linear_builder(


# Type-alias so that users only need to import this file.
LinearRuntimeParams = runtime_params_lib.RuntimeParams
LinearRuntimeParams: TypeAlias = runtime_params_lib.RuntimeParams


@dataclasses.dataclass(kw_only=True)
Expand Down
4 changes: 2 additions & 2 deletions torax/stepper/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from __future__ import annotations

import dataclasses
from typing import Any, Iterable
from typing import Any, Iterable, TypeAlias

import chex
from torax import interpolated_param
from torax.config import config_args


TimeInterpolated = interpolated_param.TimeInterpolated
TimeInterpolated: TypeAlias = interpolated_param.TimeInterpolated


@dataclasses.dataclass(kw_only=True)
Expand Down
4 changes: 2 additions & 2 deletions torax/time_step_calculator/array_time_step_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
A TimeStepCalculator that iterates over entries in an array.
"""

from typing import Union
from typing import TypeAlias, Union

import chex
import jax
Expand All @@ -27,7 +27,7 @@
from torax.config import runtime_params_slice
from torax.time_step_calculator import time_step_calculator

State = int
State: TypeAlias = int


# TODO(b/337844885). Remove the array option and make fixed_dt time-dependent
Expand Down
6 changes: 3 additions & 3 deletions torax/time_step_calculator/chi_time_step_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

import functools
from typing import Union
from typing import Final, TypeAlias, Union

import jax
from jax import numpy as jnp
Expand All @@ -29,8 +29,8 @@
from torax.time_step_calculator import time_step_calculator

# Dummy state and type for compatibility with time_step_calculator base class
STATE = None
State = type(STATE)
STATE: Final[None] = None
State: TypeAlias = type(STATE)


class ChiTimeStepCalculator(time_step_calculator.TimeStepCalculator[State]):
Expand Down
4 changes: 3 additions & 1 deletion torax/transport_model/base_qlknn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
"""Base class for QLKNN Models."""

import abc
from typing import TypeAlias

import jax


ModelOutput = dict[str, jax.Array]
ModelOutput: TypeAlias = dict[str, jax.Array]


class BaseQLKNNModel(abc.ABC):
Expand Down
4 changes: 2 additions & 2 deletions torax/transport_model/qlknn_10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections.abc import Mapping
import json
import os
from typing import Any, Callable
from typing import Any, Callable, Final

import flax.linen as nn
import immutabledict
Expand All @@ -29,7 +29,7 @@
# Internal import.

# Move this to common lib.
_ACTIVATION_FNS: Mapping[str, Callable[[jax.Array], jax.Array]] = (
_ACTIVATION_FNS: Final[Mapping[str, Callable[[jax.Array], jax.Array]]] = (
immutabledict.immutabledict({
'relu': nn.relu,
'tanh': nn.tanh,
Expand Down
8 changes: 5 additions & 3 deletions torax/transport_model/qlknn_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import functools
import logging
import os
from typing import Callable
from typing import Callable, Final

import chex
import jax
Expand All @@ -45,7 +45,7 @@

# Environment variable for the QLKNN model. Used if the model path
# is not set in the config.
MODEL_PATH_ENV_VAR = 'TORAX_QLKNN_MODEL_PATH'
MODEL_PATH_ENV_VAR: Final[str] = 'TORAX_QLKNN_MODEL_PATH'
# If no path is set in either the config or the environment variable, use
# this path.
DEFAULT_MODEL_PATH = '~/qlknn_hyper'
Expand Down Expand Up @@ -108,7 +108,9 @@ class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
q_sawtooth_proxy: bool


_EPSILON_NN: float = 1 / 3 # fixed inverse aspect ratio used to train QLKNN10D
_EPSILON_NN: Final[float] = (
1 / 3
) # fixed inverse aspect ratio used to train QLKNN10D


# Memoize, but evict the old model if a new path is given.
Expand Down
4 changes: 3 additions & 1 deletion torax/transport_model/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from __future__ import annotations

from typing import TypeAlias

import chex
from torax import interpolated_param
from torax import jax_utils
Expand All @@ -29,7 +31,7 @@
# Type-alias for clarity. While the InterpolatedVarSingleAxiss can vary across
# any field, in these classes, we mainly use it to handle time-dependent
# parameters.
TimeInterpolated = interpolated_param.TimeInterpolated
TimeInterpolated: TypeAlias = interpolated_param.TimeInterpolated


# pylint: disable=invalid-name
Expand Down

0 comments on commit 1085654

Please sign in to comment.