Skip to content

Commit

Permalink
update typing.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Mar 1, 2024
1 parent 34a351d commit a61030a
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions src/elisa/util/typing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
from typing import Union
"""Typing aliases to shorten hints."""
from typing import Callable, TypeVar, Union

import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike

FloatType = jnp.result_type(float)
IntType = jnp.result_type(int)
T = TypeVar('T')

PyFloat = Union[float, np.inexact] # must include 0-d NDArray with float dtype
JAXFloat = Array
Float = Union[PyFloat, JAXFloat]

PRNGKey = Array

NumPyArray = np.ndarray
JAXArray = Array
NumpyArray = np.ndarray
Array = Union[NumpyArray, JAXArray]
Array = Union[NumPyArray, JAXArray]

ArrayLike = ArrayLike

# Type aliases for parameter and model module
CompID = CompName = CompParamName = ParamID = ParamName = str
NameValMapping = dict[CompParamName, JAXFloat]
CompIDParamValMapping = dict[CompID, NameValMapping]
CompIDStrMapping = dict[CompID, str]
ParamIDStrMapping = dict[ParamID, str]
ParamIDValMapping = dict[ParamID, JAXFloat]
CompEval = Callable[[JAXArray, NameValMapping], JAXArray]
ConvolveEval = Callable[
[JAXArray, NameValMapping, Callable[[JAXArray], JAXArray]], JAXArray
]
ModelEval = Callable[[JAXArray, CompIDParamValMapping], JAXArray]
ModelCompiledFn = Callable[[JAXArray, ParamIDValMapping], JAXArray]
NameLaTeX = tuple[str, str]
AdditiveFn = Callable[[JAXArray, ParamIDValMapping], dict[NameLaTeX, JAXArray]]

0 comments on commit a61030a

Please sign in to comment.