Skip to content

Commit

Permalink
feat: custom models with Python callback (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve authored May 27, 2024
1 parent 221dab5 commit a133160
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/elisa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
ConvolutionComponent as ConvolutionComponent,
NumIntAdditive as NumIntAdditive,
NumIntMultiplicative as NumIntMultiplicative,
ParamConfig as ParamConfig,
PyAnaInt as PyAnaInt,
PyNumInt as PyNumInt,
)
from .models.parameter import (
CompositeParameter as CompositeParameter,
Expand Down
3 changes: 3 additions & 0 deletions src/elisa/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
ConvolutionComponent as ConvolutionComponent,
NumIntAdditive as NumIntAdditive,
NumIntMultiplicative as NumIntMultiplicative,
ParamConfig as ParamConfig,
PyAnaInt as PyAnaInt,
PyNumInt as PyNumInt,
)
from .mul import * # noqa: F403
from .parameter import (
Expand Down
102 changes: 101 additions & 1 deletion src/elisa/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from elisa.data.base import ObservationData, ResponseData, SpectrumData
from elisa.models.parameter import Parameter, UniformParameter
from elisa.util.misc import build_namespace, make_pretty_table
from elisa.util.misc import build_namespace, define_fdjvp, make_pretty_table

if TYPE_CHECKING:
from typing import Any, Callable, Literal
Expand Down Expand Up @@ -1979,6 +1979,106 @@ def convolve(*args, **kwargs) -> JAXArray:
pass


class PyComponent(Component):
"""Prototype component with pure Python expression defined."""

_kwargs: tuple[str, ...] = ('grad_method',)

def __init__(
self,
params: dict,
latex: str | None,
grad_method: Literal['central', 'forward'] | None,
):
self.grad_method = grad_method

super().__init__(params, latex)

@property
def grad_method(self) -> Literal['central', 'forward']:
"""Numerical differentiation method."""
return self._grad_method

@grad_method.setter
def grad_method(self, value: Literal['central', 'forward'] | None):
if value is None:
value: Literal['central'] = 'central'

if value not in {'central', 'forward'}:
raise ValueError(
f"supported methods are 'central' and 'forward', but got "
f"'{value}'"
)
self._grad_method = value


class PyAnaInt(PyComponent, AnalyticalIntegral):
"""Prototype component with python integral expression defined."""

@property
def eval(self) -> CompEval:
if self._integral_jit is None:
integral_fn = self.integral

def eval_integral(egrid, params):
egrid = np.asarray(egrid)
params = {k: np.asarray(v) for k, v in params.items()}
return integral_fn(egrid, params)

def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
shape_dtype = jax.ShapeDtypeStruct(
(egrid.size - 1,), egrid.dtype
)
return jax.pure_callback(
eval_integral, shape_dtype, egrid, params
)

self._integral_jit = jax.jit(
define_fdjvp(jax.jit(integral), self.grad_method)
)

return self._integral_jit


class PyNumInt(PyComponent, NumericalIntegral):
"""Prototype component with python continuum expression defined."""

_kwargs = ('method', 'grad_method')

def __init__(
self,
params: dict,
latex: str | None,
method: Literal['trapz', 'simpson'] | None,
grad_method: Literal['central', 'forward'] | None,
):
super().__init__(params, latex, grad_method)
self.method = 'trapz' if method is None else method

@property
def eval(self) -> CompEval:
if self._continuum_jit is None:
# continuum is assumed to be a pure function, independent of self
continuum_fn = self.continuum

def eval_continuum(egrid, params):
egrid = np.asarray(egrid)
params = {k: np.asarray(v) for k, v in params.items()}
return continuum_fn(egrid, params)

def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray:
shape_dtype = jax.ShapeDtypeStruct(egrid.shape, egrid.dtype)
return jax.pure_callback(
eval_continuum, shape_dtype, egrid, params
)

self._continuum_jit = jax.jit(
define_fdjvp(jax.jit(continuum), self.grad_method)
)

return self._make_integral(self._continuum_jit)


def get_model_info(
comps: Sequence[Component],
cid_to_name: CompIDStrMapping,
Expand Down
52 changes: 51 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import jax
import numpy as np
from astropy.cosmology import Planck18
from astropy.units import Unit

from elisa import ConstantValue
from elisa import ConstantValue, ParamConfig, PyAnaInt, PyNumInt
from elisa.models import PhAbs, PLPhFlux, PowerLaw, ZAShift


Expand All @@ -25,6 +26,55 @@ def test_str_repr():
assert model._repr_html_() == compiled._repr_html_()


def test_custom_model():
class PL1(PyAnaInt):
_config = (
ParamConfig('alpha', r'\alpha', '', 1.01, -5.0, 5.0),
ParamConfig('K', 'K', '', 1.0, 1e-10, 1e10),
)
type: str = 'add'

@staticmethod
def integral(egrid, params):
alpha = params['alpha']
K = params['K']
one_minus_alpha = 1.0 - alpha
f = egrid**one_minus_alpha / one_minus_alpha
return K * (f[1:] - f[:-1])

class PL2(PyNumInt):
_config = (
ParamConfig('alpha', r'\alpha', '', 1.01, -5.0, 5.0),
ParamConfig('K', 'K', '', 1.0, 1e-10, 1e10),
)
type: str = 'add'

@staticmethod
def continuum(egrid, params):
return params['K'] * egrid ** -params['alpha']

egrid = np.geomspace(1, 10, 1000)
m0 = PowerLaw().compile()
m1 = PL1().compile()
m2 = PL2().compile()
assert np.allclose(m0.eval(egrid), m1.eval(egrid))
assert np.allclose(m0.eval(egrid), m2.eval(egrid))

grad_fn0 = jax.grad(lambda params: m0.eval(egrid, params).sum())
grad_fn1 = jax.grad(lambda params: m1.eval(egrid, params).sum())
grad_fn2 = jax.grad(lambda params: m2.eval(egrid, params).sum())
params0 = {f'{m0.name}.alpha': 1.01, f'{m0.name}.K': 1.0}
params1 = {f'{m1.name}.alpha': 1.01, f'{m1.name}.K': 1.0}
params2 = {f'{m2.name}.alpha': 1.01, f'{m2.name}.K': 1.0}
grad0 = grad_fn0(params0)
grad1 = grad_fn1(params1)
grad2 = grad_fn2(params2)
assert np.allclose(grad0[f'{m0.name}.alpha'], grad1[f'{m1.name}.alpha'])
assert np.allclose(grad0[f'{m0.name}.K'], grad1[f'{m1.name}.K'])
assert np.allclose(grad0[f'{m0.name}.alpha'], grad2[f'{m2.name}.alpha'])
assert np.allclose(grad0[f'{m0.name}.K'], grad2[f'{m2.name}.K'])


def test_lumin_and_eiso():
def powerlaw(alpha, K, egrid):
egrid = np.array(egrid)
Expand Down

0 comments on commit a133160

Please sign in to comment.