diff --git a/src/elisa/util/integrate.py b/src/elisa/util/integrate.py new file mode 100644 index 00000000..4c4eaa91 --- /dev/null +++ b/src/elisa/util/integrate.py @@ -0,0 +1,98 @@ +"""Numerical integration.""" +from __future__ import annotations + +from typing import Any, Callable, Literal, get_args + +import jax.numpy as jnp +from quadax import quadcc, quadgk, quadts, romberg, rombergts + +from elisa.util.typing import ( + JAXArray, + JAXFloat, + ModelCompiledFn, + ParamID, + ParamIDValMapping, +) + +AdaptQuadMethod = Literal['quadgk', 'quadcc', 'quadts', 'romberg', 'rombergts'] +IntegralFactory = Callable[[ModelCompiledFn], ModelCompiledFn] + +_QUAD_FN = dict( + zip( + get_args(AdaptQuadMethod), [quadgk, quadcc, quadts, romberg, rombergts] + ) +) + + +def make_integral_factory( + param_id: ParamID, + interval: JAXArray, + method: AdaptQuadMethod = 'quadgk', + kwargs: dict[str, Any] | None = None, +) -> Callable[[ModelCompiledFn], ModelCompiledFn]: + """Get integral factory over the interval. + + Parameters + ---------- + param_id : str + Parameter ID. + interval: array_like + The interval, a 2-element sequence. + method : {'quadgk', 'quadcc', 'quadts', 'romberg', 'rombergts'}, optional + Numerical integration method used to integrate over the parameter. + Available options are: + * 'quadgk' : global adaptive quadrature with Gauss-Konrod rule + * 'quadcc' : global adaptive quadrature with Clenshaw-Curtis rule + * 'quadts' : global adaptive quadrature with trapz tanh-sinh rule + * 'romberg' : Romberg integration + * 'rombergts' : Romberg integration with tanh-sinh (a.k.a. double + exponential) transformation + The default is 'quadgk'. + kwargs : dict, optional + Extra kwargs passed to integration methods. See [1]_ for details. + + Returns + ------- + integral_factory : callable + Given a model function, the integral factory outputs a new model + function with the interval parameter being integrated out. + + References + ---------- + .. [1] `quadax docs `_ + + """ + if method not in _QUAD_FN: + raise ValueError(f'unsupported method: {method}') + + if jnp.shape(interval) != (2,): + raise ValueError('interval must be sequence of length 2') + + quad = _QUAD_FN[method] + interval = jnp.asarray(interval, float) + kwargs = dict(kwargs) if kwargs is not None else {} + + def integral_factory(model_fn: ModelCompiledFn) -> ModelCompiledFn: + """Integrate the model_fn over the interval.""" + + def integrand( + value: JAXFloat, + egrid: JAXArray, + params: ParamIDValMapping, + ) -> JAXArray: + """The integrand function.""" + params[param_id] = value + return model_fn(egrid, params) + + def integral(egrid: JAXArray, params: ParamIDValMapping) -> JAXArray: + """New model_fn with interval param being integrated out.""" + args = (egrid, params) + quad_result = quad(integrand, interval, args, **kwargs)[0] + # when the integrand is independent of the interval parameter, + # the result is unaffected due to the 1/(b-a) factor + return quad_result / (interval[1] - interval[0]) + + return integral + + return integral_factory