diff --git a/src/elisa/util/typing.py b/src/elisa/util/typing.py index 9bd7f5fd..978afc9a 100644 --- a/src/elisa/util/typing.py +++ b/src/elisa/util/typing.py @@ -1,14 +1,36 @@ -from typing import Union +"""Typing aliases to shorten hints.""" +from typing import Callable, TypeVar, Union -import jax.numpy as jnp import numpy as np from jax import Array +from jax.typing import ArrayLike -FloatType = jnp.result_type(float) -IntType = jnp.result_type(int) +T = TypeVar('T') +PyFloat = Union[float, np.inexact] # must include 0-d NDArray with float dtype JAXFloat = Array +Float = Union[PyFloat, JAXFloat] + PRNGKey = Array + +NumPyArray = np.ndarray JAXArray = Array -NumpyArray = np.ndarray -Array = Union[NumpyArray, JAXArray] +Array = Union[NumPyArray, JAXArray] + +ArrayLike = ArrayLike + +# Type aliases for parameter and model module +CompID = CompName = CompParamName = ParamID = ParamName = str +NameValMapping = dict[CompParamName, JAXFloat] +CompIDParamValMapping = dict[CompID, NameValMapping] +CompIDStrMapping = dict[CompID, str] +ParamIDStrMapping = dict[ParamID, str] +ParamIDValMapping = dict[ParamID, JAXFloat] +CompEval = Callable[[JAXArray, NameValMapping], JAXArray] +ConvolveEval = Callable[ + [JAXArray, NameValMapping, Callable[[JAXArray], JAXArray]], JAXArray +] +ModelEval = Callable[[JAXArray, CompIDParamValMapping], JAXArray] +ModelCompiledFn = Callable[[JAXArray, ParamIDValMapping], JAXArray] +NameLaTeX = tuple[str, str] +AdditiveFn = Callable[[JAXArray, ParamIDValMapping], dict[NameLaTeX, JAXArray]]