Skip to content

Commit

Permalink
refactor: parameter and model
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Feb 2, 2024
1 parent c8afa7d commit ecdcdec
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 32 deletions.
189 changes: 157 additions & 32 deletions src/elisa/model/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,28 @@

import inspect
from abc import ABCMeta, abstractmethod
from typing import Literal
from typing import Any, Literal

import jax.numpy as jnp

from elisa.model.core.parameter import ParameterBase
from elisa.util.typing import Array, JAXArray

ParamConfig = tuple[tuple[str, str, str, float, float, float, bool, bool], ...]
ExtraKw = tuple[tuple[str, Any], ...]


# __all__ = [
# 'ModelBase',
# 'Model',
# 'CompositeModel',
# 'Component',
# ]


class ModelBase(metaclass=ABCMeta):
"""Base model class."""

def __init__(self):
self._id = hex(id(self))[2:]

Expand Down Expand Up @@ -54,6 +68,7 @@ def __rmul__(self, other: ModelBase) -> CompositeModel:

class Model(ModelBase):
def __init__(self, component: Component):
self._component = component
self._name = str(component.name)
self._latex = str(component.latex)
self._type = component.type
Expand Down Expand Up @@ -219,19 +234,23 @@ class ComponentMeta(ABCMeta):
"""Avoid cumbersome coding for subclass ``__init__``."""

def __new__(cls, name, bases, dct, **kwargs) -> ComponentMeta:
# define subclass __init__ method
if 'config' in dct:
config = dct['config']
# check config and then define subclass __init__ method
if isinstance(config := dct.get('_config', None), tuple):
if any(len(cfg) != 8 for cfg in config):
raise ValueError(
f'each parameter configuration of {name} should have '
'8 values: name (str), latex (str), unit (str), '
'default (float), min (float), max (float), '
'log (bool), fixed (bool)'
)

init_def = 'self, '
init_body = ''
par_body = '('
for cfg in config:
init_def += cfg[0] + '=None, '
init_def += f'{cfg[0]}: ParameterBase | None = None, '
init_body += f'{cfg[0]}={cfg[0]}, '
par_body += f'{cfg[0]}, '
par_body += ')'

init_def += 'latex=None'
init_def += 'latex: str | None = None'
init_body += 'latex=latex'

if hasattr(cls, '_extra_kw') and isinstance(cls._extra_kw, tuple):
Expand All @@ -252,29 +271,25 @@ def __new__(cls, name, bases, dct, **kwargs) -> ComponentMeta:
func_code += 'super(type(self), type(self))'
func_code += f'.__init__(self, {init_body})\n'

exec(func_code, tmp := {})
exec(func_code, tmp := {'ParameterBase': ParameterBase})
__init__ = tmp['__init__']
__init__.__qualname__ = f'{name}.__init__'
dct['__init__'] = __init__

return super().__new__(cls, name, bases, dct)

def __init__(cls, name, bases, dct, **kwargs) -> None:
# restore the signature of __init__
# see https://stackoverflow.com/a/65385559
super().__init__(name, bases, dct, **kwargs)
sig = inspect.signature(cls.__init__)
parameters = tuple(sig.parameters.values())
cls.__signature__ = sig.replace(parameters=parameters[1:])
super().__init__(name, bases, dct, **kwargs)

def __call__(cls, *args, **kwargs) -> Model:
# return Model object after Component initialization
return Model(super().__call__(*args, **kwargs))


class Component(metaclass=ComponentMeta):
config: tuple[tuple[str, str, str, float, float, float, bool, bool], ...]

def __init__(self, latex: str | None = None, **params) -> None:
name = self.__class__.__name__

Expand All @@ -285,6 +300,21 @@ def __init__(self, latex: str | None = None, **params) -> None:
self.latex = str(latex)
self.params = params

@property
@abstractmethod
def _config(self) -> ParamConfig:
"""Configuration of parameters."""
pass

@property
def _extra_kw(self) -> ExtraKw:
"""Extra keywords passed to ``__init__`` method.
Note that element of inner tuple should respect :func:`repr`.
"""
return ()

@staticmethod
@abstractmethod
def eval(*args, **kwargs) -> JAXArray:
Expand All @@ -296,36 +326,131 @@ def type(self) -> Literal['add', 'mul']:
pass


class Powerlaw(Component):
r"""Powerlaw function.
class AdditiveComponent(Component, metaclass=ABCMeta):
"""Prototype class to define additive component."""

def eval(self, egrid: Array, *args, **kwargs) -> JAXArray:
return self.integrate(egrid, *args, **kwargs)

@abstractmethod
def integrate(self, *args, **kwargs) -> JAXArray:
pass

@property
def type(self) -> Literal['add']:
"""Model type is additive."""
return 'add'


class NumIntAdditive(AdditiveComponent, metaclass=ABCMeta):
_extra_kw = (('method', 'trapz'),)

def __init__(self, method='trapz', **kwargs):
self.method = method
super().__init__(**kwargs)

@staticmethod
@abstractmethod
def continnum(egrid: Array, *args, **kwargs) -> JAXArray:
pass

def integrate(self, egrid: Array, *args, **kwargs) -> JAXArray:
if self.method == 'trapz':
de = egrid[1:] - egrid[:-1]
f_grid = self.continnum(egrid, *args, **kwargs)
return 0.5 * (f_grid[:-1] + f_grid[1:]) * de

elif self.method == 'simpson':
de = egrid[1:] - egrid[:-1]
e_mid = (egrid[:-1] + egrid[1:]) / 2.0
f_grid = self.continnum(egrid, *args, **kwargs)
f_mid = self.continnum(e_mid, *args, **kwargs)
return de / 6.0 * (f_grid[:-1] + 4.0 * f_mid + f_grid[1:])

else:
raise NotImplementedError(f'integration method "{self.method}"')

@property
def method(self) -> str:
"""Numerical integration method."""
return self._method

@method.setter
def method(self, method: str):
method = str(method)

if method not in ('trapz', 'simpson'):
raise ValueError(
f'available integration methods are "trapz" and "simpson", '
f'but got "{method}"'
)

self._method = method


class Powerlaw(AdditiveComponent):
r"""The power law function, defined as
.. math::
\frac{dN(E)}{dA dt dE} = K \frac{E}{E_\mathrm{pivot}}^{\alpha}
N(E) = K \left(\frac{E}{E_\mathrm{pivot}}\right)^{-\alpha},
where :math:`E_\mathrm{pivot}` is the pivot energy fixed at 1 keV.
Parameters
----------
alpha : parameter
The photon index.
K : parameter
The normalization.
alpha : ParameterBase
The power law photon index, dimensionless.
K : ParameterBase
The normalization at 1 keV, in units of :math:`\mathrm{cm}^{-2} \,
\mathrm{s}^{-1} \, \mathrm{keV}^{-1}`.
latex : str, optional
LaTeX representation of the model. The default is as its class name.
"""

config = (
('alpha', r'\alpha', '', 1.01, -3.0, 10.0, False, False),
('K', 'K', '1 / (keV s cm2)', 1.0, 1e-10, 1e10, False, False),
_config = (
('alpha', r'\alpha', '', -1.01, -10.0, 3.0, False, False),
('K', 'K', 'cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10, False, False),
)

@staticmethod
def eval(egrid: Array, alpha, K) -> JAXArray:
return K * jnp.power(egrid, -alpha)
def integrate(egrid, alpha, K) -> JAXArray:
# we ignore the case of alpha = 1.0
tmp = 1.0 - alpha
f = K / tmp * jnp.power(egrid, tmp)
return f[1:] - f[:-1]

@property
def type(self) -> Literal['add']:
return 'add'

class Cutoffpl(NumIntAdditive):
r"""The cut-off power law function, defined as
.. math::
N(E) = K \left(\frac{E}{E_\mathrm{pivot}}\right)^{-\alpha}
\exp\left(-\frac{E}{E_\mathrm{c}}\right),
where :math:`E_\mathrm{pivot}` is the pivot energy fixed at 1 keV.
if __name__ == '__main__':
m = Powerlaw()
Parameters
----------
alpha : ParameterBase, optional
The power law photon index, dimensionless.
Ec : ParameterBase, optional
The folding energy of exponential rolloff, in units of keV.
K : ParameterBase, optional
The normalization at 1 keV, in units of :math:`\mathrm{cm}^{-2} \,
\mathrm{s}^{-1} \, \mathrm{keV}^{-1}`.
latex : str, optional
LaTeX representation of the model. The default is as its class name.
"""

_config = (
('alpha', r'\alpha', '', 1.0, -3.0, 10.0, False, False),
('Ec', r'E_\mathrm{c}', 'keV', 15.0, 0.01, 10000.0, False, False),
('K', 'K', 'cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10, False, False),
)

@staticmethod
def _continnum(egrid, alpha, Ec, K) -> JAXArray:
e = egrid
return K * jnp.power(e, alpha) * jnp.exp(-e / Ec)
1 change: 1 addition & 0 deletions src/elisa/model/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


__all__ = [
'ParameterBase',
'Parameter',
'UniformParameter',
'ConstantValue',
Expand Down

0 comments on commit ecdcdec

Please sign in to comment.