Skip to content

Commit

Permalink
numerical integration for integrating parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Mar 1, 2024
1 parent d9fdd1d commit d9fe47e
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions src/elisa/util/integrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Numerical integration."""
from __future__ import annotations

from typing import Any, Callable, Literal, get_args

import jax.numpy as jnp
from quadax import quadcc, quadgk, quadts, romberg, rombergts

from elisa.util.typing import (
JAXArray,
JAXFloat,
ModelCompiledFn,
ParamID,
ParamIDValMapping,
)

AdaptQuadMethod = Literal['quadgk', 'quadcc', 'quadts', 'romberg', 'rombergts']
IntegralFactory = Callable[[ModelCompiledFn], ModelCompiledFn]

_QUAD_FN = dict(
zip(
get_args(AdaptQuadMethod), [quadgk, quadcc, quadts, romberg, rombergts]
)
)


def make_integral_factory(
param_id: ParamID,
interval: JAXArray,
method: AdaptQuadMethod = 'quadgk',
kwargs: dict[str, Any] | None = None,
) -> Callable[[ModelCompiledFn], ModelCompiledFn]:
"""Get integral factory over the interval.
Parameters
----------
param_id : str
Parameter ID.
interval: array_like
The interval, a 2-element sequence.
method : {'quadgk', 'quadcc', 'quadts', 'romberg', 'rombergts'}, optional
Numerical integration method used to integrate over the parameter.
Available options are:
* 'quadgk' : global adaptive quadrature with Gauss-Konrod rule
* 'quadcc' : global adaptive quadrature with Clenshaw-Curtis rule
* 'quadts' : global adaptive quadrature with trapz tanh-sinh rule
* 'romberg' : Romberg integration
* 'rombergts' : Romberg integration with tanh-sinh (a.k.a. double
exponential) transformation
The default is 'quadgk'.
kwargs : dict, optional
Extra kwargs passed to integration methods. See [1]_ for details.
Returns
-------
integral_factory : callable
Given a model function, the integral factory outputs a new model
function with the interval parameter being integrated out.
References
----------
.. [1] `quadax docs <https://quadax.readthedocs.io/en/latest/api.html
#adaptive-integration-of-a-callable-function-or-method>`_
"""
if method not in _QUAD_FN:
raise ValueError(f'unsupported method: {method}')

if jnp.shape(interval) != (2,):
raise ValueError('interval must be sequence of length 2')

quad = _QUAD_FN[method]
interval = jnp.asarray(interval, float)
kwargs = dict(kwargs) if kwargs is not None else {}

def integral_factory(model_fn: ModelCompiledFn) -> ModelCompiledFn:
"""Integrate the model_fn over the interval."""

def integrand(
value: JAXFloat,
egrid: JAXArray,
params: ParamIDValMapping,
) -> JAXArray:
"""The integrand function."""
params[param_id] = value
return model_fn(egrid, params)

def integral(egrid: JAXArray, params: ParamIDValMapping) -> JAXArray:
"""New model_fn with interval param being integrated out."""
args = (egrid, params)
quad_result = quad(integrand, interval, args, **kwargs)[0]
# when the integrand is independent of the interval parameter,
# the result is unaffected due to the 1/(b-a) factor
return quad_result / (interval[1] - interval[0])

return integral

return integral_factory

0 comments on commit d9fe47e

Please sign in to comment.