diff --git a/src/elisa/model/conv.py b/src/elisa/model/conv.py index ff4ebb86..fbc792d9 100644 --- a/src/elisa/model/conv.py +++ b/src/elisa/model/conv.py @@ -1,6 +1,7 @@ """ConvolutionComponent models.""" from __future__ import annotations +from abc import abstractmethod from typing import Callable import jax @@ -12,7 +13,7 @@ __all__ = ['EnFlux', 'PhFlux', 'RedShift', 'VelocityShift'] -class FluxNorm(ConvolutionComponent): +class NormConvolution(ConvolutionComponent): _args = ('emin', 'emax') _kwargs = ('ngrid', 'elog') _supported = frozenset({'add'}) @@ -39,6 +40,36 @@ def __init__( super().__init__(params, latex) + @staticmethod + @abstractmethod + def convolve( + egrid: JAXArray, + params: NameValMapping, + model_fn: Callable[[JAXArray], JAXArray], + flux_egrid: JAXArray, + ) -> JAXArray: + """Convolve a model function. + + Parameters + ---------- + egrid : ndarray + Photon energy grid in units of keV. + params : dict + Parameter dict for the convolution model. + model_fn : callable + The model function to be convolved, which takes the energy grid as + input and returns the model flux over the grid. + flux_egrid : ndarray + Photon energy grid used to calculate flux in units of keV. + + Returns + ------- + value : ndarray + The re-normalized model over `egrid`, in units of cm⁻² s⁻¹ keV⁻¹. + + """ + pass + @property def eval(self) -> ConvolveEval: if self._prev_config == (self.emin, self.emax, self.ngrid, self.elog): @@ -107,7 +138,7 @@ def elog(self, value: bool): self._elog = bool(value) -class PhFlux(FluxNorm): +class PhFlux(NormConvolution): r"""Normalize an additive model by photon flux between `emin` and `emax`. Warnings @@ -125,7 +156,7 @@ class PhFlux(FluxNorm): Flux parameter, in units of cm⁻² s⁻¹. latex : str, optional :math:`\LaTeX` format of the component. Defaults to class name. - ngrid : int or None, optional + ngrid : int, optional The energy grid number to use. The default is 1000. elog : bool, optional Whether to use logarithmically regular energy grids. @@ -152,7 +183,7 @@ def convolve( return F / mflux * flux -class EnFlux(FluxNorm): +class EnFlux(NormConvolution): r"""Normalize an additive model by energy flux between `emin` and `emax`. Warnings @@ -170,7 +201,7 @@ class EnFlux(FluxNorm): Flux parameter, in units of erg cm⁻² s⁻¹. latex : str, optional :math:`\LaTeX` format of the component. Defaults to class name. - ngrid : int or None, optional + ngrid : int, optional The energy grid number to use. The default is 1000. elog : bool, optional Whether to use logarithmically regular energy grids. diff --git a/src/elisa/model/model.py b/src/elisa/model/model.py index 5970c978..441edd32 100644 --- a/src/elisa/model/model.py +++ b/src/elisa/model/model.py @@ -1394,13 +1394,13 @@ def convolve(*args, **kwargs) -> JAXArray: params : dict Parameter dict for the convolution model. model_fn : callable - The model function to be convolved, which takes energy grid as - input and returns the model value at the grid. + The model function to be convolved, which takes the energy grid as + input and returns the model value over the grid. Returns ------- value : ndarray - The convolved model value at the energy grid. + The convolved model value over `egrid`. """ pass