diff --git a/src/elisa/model/core/model.py b/src/elisa/model/core/model.py index 8ea64544..775da840 100644 --- a/src/elisa/model/core/model.py +++ b/src/elisa/model/core/model.py @@ -3,14 +3,28 @@ import inspect from abc import ABCMeta, abstractmethod -from typing import Literal +from typing import Any, Literal import jax.numpy as jnp +from elisa.model.core.parameter import ParameterBase from elisa.util.typing import Array, JAXArray +ParamConfig = tuple[tuple[str, str, str, float, float, float, bool, bool], ...] +ExtraKw = tuple[tuple[str, Any], ...] + + +# __all__ = [ +# 'ModelBase', +# 'Model', +# 'CompositeModel', +# 'Component', +# ] + class ModelBase(metaclass=ABCMeta): + """Base model class.""" + def __init__(self): self._id = hex(id(self))[2:] @@ -54,6 +68,7 @@ def __rmul__(self, other: ModelBase) -> CompositeModel: class Model(ModelBase): def __init__(self, component: Component): + self._component = component self._name = str(component.name) self._latex = str(component.latex) self._type = component.type @@ -219,19 +234,23 @@ class ComponentMeta(ABCMeta): """Avoid cumbersome coding for subclass ``__init__``.""" def __new__(cls, name, bases, dct, **kwargs) -> ComponentMeta: - # define subclass __init__ method - if 'config' in dct: - config = dct['config'] + # check config and then define subclass __init__ method + if isinstance(config := dct.get('_config', None), tuple): + if any(len(cfg) != 8 for cfg in config): + raise ValueError( + f'each parameter configuration of {name} should have ' + '8 values: name (str), latex (str), unit (str), ' + 'default (float), min (float), max (float), ' + 'log (bool), fixed (bool)' + ) + init_def = 'self, ' init_body = '' - par_body = '(' for cfg in config: - init_def += cfg[0] + '=None, ' + init_def += f'{cfg[0]}: ParameterBase | None = None, ' init_body += f'{cfg[0]}={cfg[0]}, ' - par_body += f'{cfg[0]}, ' - par_body += ')' - init_def += 'latex=None' + init_def += 'latex: str | None = None' init_body += 'latex=latex' if hasattr(cls, '_extra_kw') and isinstance(cls._extra_kw, tuple): @@ -252,7 +271,7 @@ def __new__(cls, name, bases, dct, **kwargs) -> ComponentMeta: func_code += 'super(type(self), type(self))' func_code += f'.__init__(self, {init_body})\n' - exec(func_code, tmp := {}) + exec(func_code, tmp := {'ParameterBase': ParameterBase}) __init__ = tmp['__init__'] __init__.__qualname__ = f'{name}.__init__' dct['__init__'] = __init__ @@ -260,12 +279,10 @@ def __new__(cls, name, bases, dct, **kwargs) -> ComponentMeta: return super().__new__(cls, name, bases, dct) def __init__(cls, name, bases, dct, **kwargs) -> None: - # restore the signature of __init__ - # see https://stackoverflow.com/a/65385559 + super().__init__(name, bases, dct, **kwargs) sig = inspect.signature(cls.__init__) parameters = tuple(sig.parameters.values()) cls.__signature__ = sig.replace(parameters=parameters[1:]) - super().__init__(name, bases, dct, **kwargs) def __call__(cls, *args, **kwargs) -> Model: # return Model object after Component initialization @@ -273,8 +290,6 @@ def __call__(cls, *args, **kwargs) -> Model: class Component(metaclass=ComponentMeta): - config: tuple[tuple[str, str, str, float, float, float, bool, bool], ...] - def __init__(self, latex: str | None = None, **params) -> None: name = self.__class__.__name__ @@ -285,6 +300,21 @@ def __init__(self, latex: str | None = None, **params) -> None: self.latex = str(latex) self.params = params + @property + @abstractmethod + def _config(self) -> ParamConfig: + """Configuration of parameters.""" + pass + + @property + def _extra_kw(self) -> ExtraKw: + """Extra keywords passed to ``__init__`` method. + + Note that element of inner tuple should respect :func:`repr`. + + """ + return () + @staticmethod @abstractmethod def eval(*args, **kwargs) -> JAXArray: @@ -296,36 +326,131 @@ def type(self) -> Literal['add', 'mul']: pass -class Powerlaw(Component): - r"""Powerlaw function. +class AdditiveComponent(Component, metaclass=ABCMeta): + """Prototype class to define additive component.""" + + def eval(self, egrid: Array, *args, **kwargs) -> JAXArray: + return self.integrate(egrid, *args, **kwargs) + + @abstractmethod + def integrate(self, *args, **kwargs) -> JAXArray: + pass + + @property + def type(self) -> Literal['add']: + """Model type is additive.""" + return 'add' + + +class NumIntAdditive(AdditiveComponent, metaclass=ABCMeta): + _extra_kw = (('method', 'trapz'),) + + def __init__(self, method='trapz', **kwargs): + self.method = method + super().__init__(**kwargs) + + @staticmethod + @abstractmethod + def continnum(egrid: Array, *args, **kwargs) -> JAXArray: + pass + + def integrate(self, egrid: Array, *args, **kwargs) -> JAXArray: + if self.method == 'trapz': + de = egrid[1:] - egrid[:-1] + f_grid = self.continnum(egrid, *args, **kwargs) + return 0.5 * (f_grid[:-1] + f_grid[1:]) * de + + elif self.method == 'simpson': + de = egrid[1:] - egrid[:-1] + e_mid = (egrid[:-1] + egrid[1:]) / 2.0 + f_grid = self.continnum(egrid, *args, **kwargs) + f_mid = self.continnum(e_mid, *args, **kwargs) + return de / 6.0 * (f_grid[:-1] + 4.0 * f_mid + f_grid[1:]) + + else: + raise NotImplementedError(f'integration method "{self.method}"') + + @property + def method(self) -> str: + """Numerical integration method.""" + return self._method + + @method.setter + def method(self, method: str): + method = str(method) + + if method not in ('trapz', 'simpson'): + raise ValueError( + f'available integration methods are "trapz" and "simpson", ' + f'but got "{method}"' + ) + + self._method = method + + +class Powerlaw(AdditiveComponent): + r"""The power law function, defined as .. math:: - \frac{dN(E)}{dA dt dE} = K \frac{E}{E_\mathrm{pivot}}^{\alpha} + N(E) = K \left(\frac{E}{E_\mathrm{pivot}}\right)^{-\alpha}, + + where :math:`E_\mathrm{pivot}` is the pivot energy fixed at 1 keV. Parameters ---------- - alpha : parameter - The photon index. - K : parameter - The normalization. + alpha : ParameterBase + The power law photon index, dimensionless. + K : ParameterBase + The normalization at 1 keV, in units of :math:`\mathrm{cm}^{-2} \, + \mathrm{s}^{-1} \, \mathrm{keV}^{-1}`. latex : str, optional LaTeX representation of the model. The default is as its class name. """ - config = ( - ('alpha', r'\alpha', '', 1.01, -3.0, 10.0, False, False), - ('K', 'K', '1 / (keV s cm2)', 1.0, 1e-10, 1e10, False, False), + _config = ( + ('alpha', r'\alpha', '', -1.01, -10.0, 3.0, False, False), + ('K', 'K', 'cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10, False, False), ) @staticmethod - def eval(egrid: Array, alpha, K) -> JAXArray: - return K * jnp.power(egrid, -alpha) + def integrate(egrid, alpha, K) -> JAXArray: + # we ignore the case of alpha = 1.0 + tmp = 1.0 - alpha + f = K / tmp * jnp.power(egrid, tmp) + return f[1:] - f[:-1] - @property - def type(self) -> Literal['add']: - return 'add' +class Cutoffpl(NumIntAdditive): + r"""The cut-off power law function, defined as + + .. math:: + N(E) = K \left(\frac{E}{E_\mathrm{pivot}}\right)^{-\alpha} + \exp\left(-\frac{E}{E_\mathrm{c}}\right), + + where :math:`E_\mathrm{pivot}` is the pivot energy fixed at 1 keV. -if __name__ == '__main__': - m = Powerlaw() + Parameters + ---------- + alpha : ParameterBase, optional + The power law photon index, dimensionless. + Ec : ParameterBase, optional + The folding energy of exponential rolloff, in units of keV. + K : ParameterBase, optional + The normalization at 1 keV, in units of :math:`\mathrm{cm}^{-2} \, + \mathrm{s}^{-1} \, \mathrm{keV}^{-1}`. + latex : str, optional + LaTeX representation of the model. The default is as its class name. + + """ + + _config = ( + ('alpha', r'\alpha', '', 1.0, -3.0, 10.0, False, False), + ('Ec', r'E_\mathrm{c}', 'keV', 15.0, 0.01, 10000.0, False, False), + ('K', 'K', 'cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10, False, False), + ) + + @staticmethod + def _continnum(egrid, alpha, Ec, K) -> JAXArray: + e = egrid + return K * jnp.power(e, alpha) * jnp.exp(-e / Ec) diff --git a/src/elisa/model/core/parameter.py b/src/elisa/model/core/parameter.py index 9b624fe6..937dd8bd 100644 --- a/src/elisa/model/core/parameter.py +++ b/src/elisa/model/core/parameter.py @@ -17,6 +17,7 @@ __all__ = [ + 'ParameterBase', 'Parameter', 'UniformParameter', 'ConstantValue',