From 9bce699f5b6a97dd834f6dfb21684bab087b6169 Mon Sep 17 00:00:00 2001 From: xuewc Date: Sun, 3 Mar 2024 05:01:11 +0800 Subject: [PATCH] docs: update --- .readthedocs.yaml | 2 +- docs/conf.py | 88 ++++++++++-- docs/index.rst | 17 +-- pyproject.toml | 8 +- src/elisa/__init__.py | 7 +- src/elisa/data/__init__.py | 8 +- src/elisa/data/grouping.py | 10 -- src/elisa/data/ogip.py | 60 ++++---- src/elisa/infer/__init__.py | 7 +- src/elisa/infer/fit.py | 10 +- src/elisa/infer/likelihood.py | 2 +- src/elisa/infer/nested_sampling.py | 1 - src/elisa/model/__init__.py | 48 ++----- src/elisa/model/add.py | 220 ++++++++++++++--------------- src/elisa/model/parameter.py | 7 +- src/elisa/util/__init__.py | 7 +- src/elisa/util/config.py | 2 - src/elisa/util/integrate.py | 5 +- src/elisa/util/typing.py | 27 ++++ 19 files changed, 301 insertions(+), 235 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index d8643793..ab6bde29 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,7 +6,7 @@ build: python: "3.9" jobs: pre_build: - - sphinx-apidoc -e -o docs/apidoc src/elisa + - sphinx-apidoc -e -o -T docs/apidoc src/elisa sphinx: configuration: docs/conf.py diff --git a/docs/conf.py b/docs/conf.py index aeefb13f..660e92e5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,8 +14,8 @@ import elisa # noqa: E402 project = 'elisa' -copyright = '2023-2024, Wang-Chen Xue & contributors' -author = 'Wang-Chen Xue' +copyright = '2023-2024, W.-C. Xue & contributors' +author = 'elisa developers' release = elisa.__version__ # -- General configuration --------------------------------------------------- @@ -23,15 +23,24 @@ extensions = [ 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.todo', + 'sphinx.ext.autosummary', 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', 'sphinx.ext.viewcode', - 'numpydoc.numpydoc', + 'sphinx_autodoc_typehints', + 'sphinx_book_theme', + 'sphinx_copybutton', + 'sphinx_design', + 'myst_nb', + 'numpydoc', ] +source_suffix = { + '.rst': 'restructuredtext', + '.ipynb': 'myst-nb', +} + templates_path = ['_templates'] exclude_patterns = ['_build'] @@ -39,7 +48,70 @@ # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'sphinx_rtd_theme' +html_theme = 'sphinx_book_theme' +html_theme_options = { + 'github_url': 'https://github.com/wcxve/elisa', + 'repository_url': 'https://github.com/wcxve/elisa', + 'repository_branch': 'main', + 'home_page_in_toc': True, + 'path_to_docs': 'docs', + 'launch_buttons': { + 'binderhub_url': 'https://mybinder.org', + 'colab_url': 'https://colab.research.google.com/', + 'notebook_interface': 'jupyterlab', + }, + 'navigation_with_keys': False, + 'use_edit_page_button': True, + 'use_repository_button': True, + 'use_download_button': True, + 'use_issues_button': True, +} # html_static_path = ['_static'] +html_baseurl = 'https://elisa-lib.readthedocs.io/en/latest/' + +intersphinx_mapping = { + 'python': ('https://docs.python.org/3.9', None), + 'arviz': ('https://python.arviz.org/en/stable/', None), + 'jax': ('https://jax.readthedocs.io/en/latest/', None), + 'matplotlib': ('https://matplotlib.org/stable/', None), + 'numpy': ('https://numpy.org/doc/stable/', None), + 'numpyro': ('https://num.pyro.ai/en/stable/', None), + 'tinygp': ('https://tinygp.readthedocs.io/en/latest/', None), +} + +myst_enable_extensions = [ + 'amsmath', + 'dollarmath', +] + +nb_ipywidgets_js = { + 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js': { + 'integrity': 'sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=', + 'crossorigin': 'anonymous', + }, + 'https://cdn.jsdelivr.net/npm/' + '@jupyter-widgets/html-manager@*/dist/embed-amd.js': { + 'data-jupyter-widgets-cdn': 'https://cdn.jsdelivr.net/npm/', + 'crossorigin': 'anonymous', + }, +} +nb_execution_mode = 'auto' +nb_execution_timeout = -1 numpydoc_attributes_as_param_list = False +numpydoc_class_members_toctree = False +numpydoc_show_class_members = True +numpydoc_show_inherited_class_members = True +numpydoc_xref_param_type = True +numpydoc_xref_ignore = {'optional', 'type_without_description', 'BadException'} +# Run docstring validation as part of build process +# numpydoc_validation_checks = {"all", "GL01", "SA04", "RT03"} +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +numpydoc_xref_aliases = { + 'ParameterBase': 'elisa.model.parameter.ParameterBase', +} + +typehints_use_signature = True +typehints_use_signature_return = True diff --git a/docs/index.rst b/docs/index.rst index 377a8c55..9b350fb5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,20 +1,9 @@ -.. elisa documentation master file, created by - sphinx-quickstart on Tue Nov 14 22:43:20 2023. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - Welcome to elisa's documentation! -==================================== +================================= .. toctree:: :maxdepth: 2 :caption: Contents: - apidoc/modules - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` + quick-start + modules diff --git a/pyproject.toml b/pyproject.toml index c3fb9283..ede2f11c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,8 +44,12 @@ xspec = ["xspex"] test = ["pytest", "pytest-cov", "coverage[toml]"] docs = [ "sphinx", - "sphinx_rtd_theme", - "numpydoc" + "sphinx-autodoc-typehints", + "sphinx-book-theme", + "sphinx-copybutton", + "sphinx-design", + "myst-nb", + "numpydoc", ] dev = ["pre-commit>=3.6.0", "ruff"] diff --git a/src/elisa/__init__.py b/src/elisa/__init__.py index 785581ea..52f475fb 100644 --- a/src/elisa/__init__.py +++ b/src/elisa/__init__.py @@ -1,11 +1,8 @@ -from . import data, infer, model, util from .__about__ import __version__ as __version__ -from .data import * # noqa F403 -from .infer import * # noqa F403 +from .data.ogip import * # noqa F403 +from .infer.fit import * # noqa F403 from .model import * # noqa F403 from .util import jax_enable_x64, set_cpu_cores -__all__ = data.__all__ + infer.__all__ + model.__all__ + util.__all__ - jax_enable_x64(True) set_cpu_cores(4) diff --git a/src/elisa/data/__init__.py b/src/elisa/data/__init__.py index a9d5811a..edf9c323 100644 --- a/src/elisa/data/__init__.py +++ b/src/elisa/data/__init__.py @@ -1,3 +1,5 @@ -from .ogip import Data - -__all__ = ['Data'] +from .ogip import ( + Data as Data, + Response as Response, + Spectrum as Spectrum, +) diff --git a/src/elisa/data/grouping.py b/src/elisa/data/grouping.py index f26a991c..3107b56d 100644 --- a/src/elisa/data/grouping.py +++ b/src/elisa/data/grouping.py @@ -6,16 +6,6 @@ from elisa.util.typing import NumPyArray as NDArray -__all__ = [ - 'group_const', - 'group_min', - 'group_sig', - 'group_pos', - 'group_opt', - 'group_optmin', - 'group_optsig', -] - GroupResultType = tuple[NDArray, bool] diff --git a/src/elisa/data/ogip.py b/src/elisa/data/ogip.py index 4cd026d4..46f452e2 100644 --- a/src/elisa/data/ogip.py +++ b/src/elisa/data/ogip.py @@ -19,8 +19,7 @@ ) from elisa.util.typing import NumPyArray as NDArray -__all__ = ['Data', 'Spectrum', 'Response'] -# TODO: support multiple sources +# TODO: support multiple sources in a single data object # TODO: support creating Data object from array @@ -53,14 +52,17 @@ class Data: group : str or None, optional Method to group spectrum and background adaptively, these options are available so that each channel group has: - * const: `scale` number channels - * min: counts >= `scale` for src + bkg - * sig: src significance >= `scale`-sigma - * opt: optimal binning, see Kaastra & Bleeker (2016) [3]_ - * optmin: opt with counts >= `scale` for src + bkg - * optsig: opt with src significance >= `scale`-sigma - * bmin: counts >= `scale` for bkg (useful for wstat) - * bpos: bkg < 0 with probability < `scale` (useful for pgstat) + + * 'const': `scale` number channels + * 'min': counts >= `scale` for src + bkg + * 'sig': src significance >= `scale`-sigma + * 'opt': optimal binning, see Kaastra & Bleeker (2016) [3]_ + * 'optmin': opt with counts >= `scale` for src + bkg + * 'optsig': opt with src significance >= `scale`-sigma + * 'bmin': counts >= `scale` for bkg (useful for wstat) + * 'bpos': bkg < 0 with probability < `scale` (useful for pgstat) + + The default is None. scale : float or None, optional Grouping scale. Only takes effect if `group` is not None. spec_poisson : bool or None, optional @@ -74,11 +76,13 @@ class Data: ignore_bad : bool, optional Whether to ignore channels with ``QUALITY==5``. The default is True. The possible values for spectral ``QUALITY`` are + * 0: good * 1: defined bad by software * 2: defined dubious by software * 5: defined bad by user * -1: reason for bad flag unknown + record_channel : bool, optional Whether to record channel information in the label of grouped channel. Only takes effect if `group` is not None or spectral data @@ -96,12 +100,12 @@ class Data: References ---------- - .. [1] `The OGIP Spectral File Format `_ - and `Addendum: Changes log `_ + .. [1] `The OGIP Spectral File Format `__ + and `Addendum: Changes log `__ .. [2] `The Calibration Requirements for Spectral Analysis (Definition of - RMF and ARF file formats) `_ - and `Addendum: Changes log `_ - .. [3] `Kaastra & Bleeker 2016, A&A, 587, A151 `_ + RMF and ARF file formats) `__ + and `Addendum: Changes log `__ + .. [3] `Kaastra & Bleeker 2016, A&A, 587, A151 `__ """ @@ -287,14 +291,16 @@ def group(self, method: str, scale: float | int): method : str Method to group spectrum and background adaptively, these options are available so that each channel group has: - * const: `scale` number channels - * min: counts >= `scale` for src + bkg - * sig: src significance >= `scale`-sigma - * opt: optimal binning, see Kaastra & Bleeker (2016, A&A) - * optmin: opt with counts >= `scale` for src + bkg - * optsig: opt with src significance >= `scale`-sigma - * bmin: counts >= `scale` for bkg (useful for W-stat) - * bpos: bkg < 0 with probability < `scale` (useful for PG-stat) + + * 'const': `scale` number channels + * 'min': counts >= `scale` for src + bkg + * 'sig': src significance >= `scale`-sigma + * 'opt': optimal binning, see Kaastra & Bleeker (2016, A&A) + * 'optmin': opt with counts >= `scale` for src + bkg + * 'optsig': opt with src significance >= `scale`-sigma + * 'bmin': counts >= `scale` for bkg (useful for W-stat) + * 'bpos': bkg<0 with probability < `scale` (useful for PG-stat) + scale : float Grouping scale. @@ -624,8 +630,8 @@ class Spectrum: References ---------- - .. [1] `The OGIP Spectral File Format `_ - and `Addendum: Changes log `_ + .. [1] `The OGIP Spectral File Format `__ + and `Addendum: Changes log `__ """ @@ -962,8 +968,8 @@ class Response: References ---------- .. [1] `The Calibration Requirements for Spectral Analysis (Definition of - RMF and ARF file formats) `_ - and `Addendum: Changes log `_ + RMF and ARF file formats) `___ + and `Addendum: Changes log `__ """ diff --git a/src/elisa/infer/__init__.py b/src/elisa/infer/__init__.py index 1059da9f..7bcd9981 100644 --- a/src/elisa/infer/__init__.py +++ b/src/elisa/infer/__init__.py @@ -1,3 +1,4 @@ -from .fit import BayesianFit, LikelihoodFit - -__all__ = ['BayesianFit', 'LikelihoodFit'] +from .fit import ( + BayesianFit as BayesianFit, + LikelihoodFit as LikelihoodFit, +) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 995112f1..e600afbc 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -37,8 +37,6 @@ replace_string, ) -__all__ = ['LikelihoodFit', 'BayesianFit'] - Statistic = Literal['chi2', 'cstat', 'pstat', 'pgstat', 'wstat'] T = TypeVar('T') @@ -55,11 +53,13 @@ class BaseFit(ABC): stat : str or sequence of str The likelihood option for the data and model. Available likelihood options are: + * 'chi2' : Gaussian data * 'cstat' : Poisson data * 'pstat' : Poisson data with known background * 'pgstat' : Poisson data with Gaussian background * 'wstat' : Poisson data with Poisson background + seed : int, optional Random number generator seed. The default is 42. @@ -595,19 +595,25 @@ def mle( strategy : {0, 1, 2}, optional Optimization strategy to use in Minuit. Available options are: + * 0: Fast. * 1: Default. * 2: Careful. This improves accuracy at the cost of time. + lopt : {'minuit', 'lm'}, optional Local optimization algorithm to use. Available options are: + * 'minuit': Migrad algorithm of Minuit. * 'lm': Levenberg-Marquardt algorithm of :mod:`jaxopt`. + The default is 'minuit'. gopt : {'ns'}, optional Global optimization algorithm to find the initial guess for MLE. Available options are: + * 'ns' : nested sampling of :mod:`jaxns`. + The default is None. nboot : int, optional Number of parametric bootstrap based on the MLE. These simulation diff --git a/src/elisa/infer/likelihood.py b/src/elisa/infer/likelihood.py index 7173e77d..f96b38dd 100644 --- a/src/elisa/infer/likelihood.py +++ b/src/elisa/infer/likelihood.py @@ -41,7 +41,7 @@ def pgstat_background( References ---------- - .. [1] `XSPEC Manual Appendix B: Statistics in XSPEC `_. + .. [1] `XSPEC Manual Appendix B: Statistics in XSPEC `__. """ variance = b_err * b_err diff --git a/src/elisa/infer/nested_sampling.py b/src/elisa/infer/nested_sampling.py index 8df22a67..2a724c9c 100644 --- a/src/elisa/infer/nested_sampling.py +++ b/src/elisa/infer/nested_sampling.py @@ -23,7 +23,6 @@ from numpyro.infer.reparam import Reparam from numpyro.infer.util import _guess_max_plate_nesting, _validate_model, log_density -__all__ = ["NestedSampler"] tfpd = tfp.distributions diff --git a/src/elisa/model/__init__.py b/src/elisa/model/__init__.py index 6020e3ca..36bcf765 100644 --- a/src/elisa/model/__init__.py +++ b/src/elisa/model/__init__.py @@ -1,45 +1,17 @@ -from . import add, conv, mul from .add import * # noqa: F403 from .conv import * # noqa: F403 from .model import ( - AdditiveComponent, - AnaIntAdditive, - AnaIntMultiplicative, - ConvolutionComponent, - MultiplicativeComponent, - NumIntAdditive, - NumIntMultiplicative, + AnaIntAdditive as AnaIntAdditive, + AnaIntMultiplicative as AnaIntMultiplicative, + ConvolutionComponent as ConvolutionComponent, + NumIntAdditive as NumIntAdditive, + NumIntMultiplicative as NumIntMultiplicative, ) from .mul import * # noqa: F403 from .parameter import ( - CompositeParameter, - ConstantInterval, - ConstantValue, - Parameter, - ParameterBase, - UniformParameter, -) - -__all__ = ( - [ - 'AdditiveComponent', - 'MultiplicativeComponent', - 'ConvolutionComponent', - 'AnaIntAdditive', - 'NumIntAdditive', - 'AnaIntMultiplicative', - 'NumIntMultiplicative', - ] - + [ - 'ParameterBase', - 'Parameter', - 'UniformParameter', - 'ConstantValue', - 'ConstantInterval', - 'CompositeParameter', - # 'GPParameter', - ] - + add.__all__ - + mul.__all__ - + conv.__all__ + CompositeParameter as CompositeParameter, + ConstantInterval as ConstantInterval, + ConstantValue as ConstantValue, + Parameter as Parameter, + UniformParameter as UniformParameter, ) diff --git a/src/elisa/model/add.py b/src/elisa/model/add.py index 984890e1..4eb5b2a8 100644 --- a/src/elisa/model/add.py +++ b/src/elisa/model/add.py @@ -8,10 +8,10 @@ from elisa.util.typing import JAXArray, NameValMapping __all__ = [ - 'Blackbody', - 'BlackbodyRad', 'Band', 'BandEp', + 'Blackbody', + 'BlackbodyRad', 'Compt', 'CutoffPL', 'Gauss', @@ -21,111 +21,6 @@ ] -class Blackbody(NumIntAdditive): - r"""Blackbody function. - - .. math:: - N(E) = \frac{C K E^2}{(kT)^4 [\exp(E/kT)-1]}, - - where :math:`C=8.0525`. - - Parameters - ---------- - kT : ParameterBase, optional - The temperature :math:`kT`, in units of keV. - K : ParameterBase, optional - The amplitude :math:`K = L_{39}/D_{10}^2`, where :math:`L_{39}` is the - source luminosity in units of 10³⁹ erg s⁻¹ and :math:`D_{10}` is the - distance to the source in units of 10 kpc. - latex : str, optional - :math:`\LaTeX` format of the component. Defaults to class name. - method : {'trapz', 'simpson'}, optional - Numerical integration method. Defaults to 'trapz'. - - """ - - _config = ( - ParamConfig('kT', 'kT', 'keV', 3.0, 1e-4, 200.0), - ParamConfig('K', 'K', '10^37 erg s^-1 kpc^-2', 1.0, 1e-10, 1e10), - ) - - @staticmethod - def continnum(egrid: JAXArray, params: NameValMapping) -> JAXArray: - kT = params['kT'] - K = params['K'] - x = egrid / kT - tmp = 8.0525 * K * egrid / (kT * kT * kT) - x_ = jnp.where( - jnp.greater_equal(x, 50.0), - 1.0, # avoid exponential overflow - x, - ) - - return jnp.where( - jnp.less_equal(x, 1e-4), - tmp, - jnp.where( - jnp.greater_equal(x, 50.0), - 0.0, # avoid exponential overflow - tmp * x / jnp.expm1(x_), - ), - ) - # return 8.0525 * K * e*e / (kT*kT*kT*kT * jnp.expm1(e / kT)) - - -class BlackbodyRad(NumIntAdditive): - r"""Blackbody function with normalization proportional to the surface area. - - .. math:: - N(E) = \frac{C K E^2}{\exp(E/kT)-1}, - - where :math:`C=1.0344 \times 10^{-3}` cm⁻² s⁻¹ keV⁻³. - - Parameters - ---------- - kT : ParameterBase, optional - The temperature :math:`kT`, in units of keV. - K : ParameterBase, optional - The amplitude :math:`K = R_\mathrm{km}^2/D_{10}^2`, where - :math:`R_\mathrm{km}` is the source radius in km and :math:`D_{10}` is - the distance to the source in units of 10 kpc. - latex : str, optional - :math:`\LaTeX` format of the component. Defaults to class name. - method : {'trapz', 'simpson'}, optional - Numerical integration method. Defaults to 'trapz'. - - """ - - _config = ( - ParamConfig('kT', 'kT', 'keV', 3.0, 1e-4, 200.0), - ParamConfig('K', 'K', '', 1.0, 1e-10, 1e10), - ) - - @staticmethod - def continnum(egrid: JAXArray, params: NameValMapping) -> JAXArray: - kT = params['kT'] - K = params['K'] - - x = egrid / kT - tmp = 1.0344e-3 * K * egrid - x_ = jnp.where( - jnp.greater_equal(x, 50.0), - 1.0, # avoid exponential overflow - x, - ) - - return jnp.where( - jnp.less_equal(x, 1e-4), - tmp * kT, - jnp.where( - jnp.greater_equal(x, 50.0), - 0.0, # avoid exponential overflow - tmp * egrid / jnp.expm1(x_), - ), - ) - # return 1.0344e-3 * K * e*e / jnp.expm1(e / kT) - - class Band(NumIntAdditive): r"""Gamma-ray burst continuum developed by Band et al. (1993) [1]_. @@ -161,7 +56,7 @@ class Band(NumIntAdditive): References ---------- .. [1] `Band, D., et al. 1993, ApJ, 413, 281 - `_ + `__ """ @@ -234,7 +129,7 @@ class BandEp(NumIntAdditive): References ---------- .. [1] `Band, D., et al. 1993, ApJ, 413, 281 - `_ + `__ """ @@ -271,6 +166,111 @@ def continnum(egrid: JAXArray, params: NameValMapping) -> JAXArray: return K * jnp.exp(log) +class Blackbody(NumIntAdditive): + r"""Blackbody function. + + .. math:: + N(E) = \frac{C K E^2}{(kT)^4 [\exp(E/kT)-1]}, + + where :math:`C=8.0525`. + + Parameters + ---------- + kT : ParameterBase, optional + The temperature :math:`kT`, in units of keV. + K : ParameterBase, optional + The amplitude :math:`K = L_{39}/D_{10}^2`, where :math:`L_{39}` is the + source luminosity in units of 10³⁹ erg s⁻¹ and :math:`D_{10}` is the + distance to the source in units of 10 kpc. + latex : str, optional + :math:`\LaTeX` format of the component. Defaults to class name. + method : {'trapz', 'simpson'}, optional + Numerical integration method. Defaults to 'trapz'. + + """ + + _config = ( + ParamConfig('kT', 'kT', 'keV', 3.0, 1e-4, 200.0), + ParamConfig('K', 'K', '10^37 erg s^-1 kpc^-2', 1.0, 1e-10, 1e10), + ) + + @staticmethod + def continnum(egrid: JAXArray, params: NameValMapping) -> JAXArray: + kT = params['kT'] + K = params['K'] + x = egrid / kT + tmp = 8.0525 * K * egrid / (kT * kT * kT) + x_ = jnp.where( + jnp.greater_equal(x, 50.0), + 1.0, # avoid exponential overflow + x, + ) + + return jnp.where( + jnp.less_equal(x, 1e-4), + tmp, + jnp.where( + jnp.greater_equal(x, 50.0), + 0.0, # avoid exponential overflow + tmp * x / jnp.expm1(x_), + ), + ) + # return 8.0525 * K * e*e / (kT*kT*kT*kT * jnp.expm1(e / kT)) + + +class BlackbodyRad(NumIntAdditive): + r"""Blackbody function with normalization proportional to the surface area. + + .. math:: + N(E) = \frac{C K E^2}{\exp(E/kT)-1}, + + where :math:`C=1.0344 \times 10^{-3}` cm⁻² s⁻¹ keV⁻³. + + Parameters + ---------- + kT : ParameterBase, optional + The temperature :math:`kT`, in units of keV. + K : ParameterBase, optional + The amplitude :math:`K = R_\mathrm{km}^2/D_{10}^2`, where + :math:`R_\mathrm{km}` is the source radius in km and :math:`D_{10}` is + the distance to the source in units of 10 kpc. + latex : str, optional + :math:`\LaTeX` format of the component. Defaults to class name. + method : {'trapz', 'simpson'}, optional + Numerical integration method. Defaults to 'trapz'. + + """ + + _config = ( + ParamConfig('kT', 'kT', 'keV', 3.0, 1e-4, 200.0), + ParamConfig('K', 'K', '', 1.0, 1e-10, 1e10), + ) + + @staticmethod + def continnum(egrid: JAXArray, params: NameValMapping) -> JAXArray: + kT = params['kT'] + K = params['K'] + + x = egrid / kT + tmp = 1.0344e-3 * K * egrid + x_ = jnp.where( + jnp.greater_equal(x, 50.0), + 1.0, # avoid exponential overflow + x, + ) + + return jnp.where( + jnp.less_equal(x, 1e-4), + tmp * kT, + jnp.where( + jnp.greater_equal(x, 50.0), + 0.0, # avoid exponential overflow + tmp * egrid / jnp.expm1(x_), + ), + ) + # return 1.0344e-3 * K * e*e / jnp.expm1(e / kT) + + class BrokenPL(AnaIntAdditive): pass @@ -467,7 +467,7 @@ class OTTS(NumIntAdditive): References ---------- .. [1] `Liang, E. P., et al., 1983, ApJ, 271, 776 - `_ + `__ """ diff --git a/src/elisa/model/parameter.py b/src/elisa/model/parameter.py index e6e870cc..281d6af9 100644 --- a/src/elisa/model/parameter.py +++ b/src/elisa/model/parameter.py @@ -652,12 +652,14 @@ class ConstantInterval(ConstantParameter): method : {'quadgk', 'quadcc', 'quadts', 'romberg', 'rombergts'}, optional Numerical integration method used to integrate over the parameter. Available options are: + * 'quadgk' : global adaptive quadrature with Gauss-Konrod rule * 'quadcc' : global adaptive quadrature with Clenshaw-Curtis rule * 'quadts' : global adaptive quadrature with trapz tanh-sinh rule * 'romberg' : Romberg integration * 'rombergts' : Romberg integration with tanh-sinh (a.k.a. double exponential) transformation + The default is 'quadgk'. latex : str, optional :math:`\LaTeX` format of the parameter. The default is as `name`. @@ -666,8 +668,7 @@ class ConstantInterval(ConstantParameter): References ---------- - .. [1] `quadax docs `_ + .. [1] `quadax docs `__ """ @@ -744,7 +745,7 @@ class CompositeParameter(ParameterBase): Parameters ---------- - params : Parameter, or sequence of Parameter + params : ParameterBase, or sequence of ParameterBase Parameters to be composed. op : callable Function to be applied to `params`. diff --git a/src/elisa/util/__init__.py b/src/elisa/util/__init__.py index 23613963..4ca12d01 100644 --- a/src/elisa/util/__init__.py +++ b/src/elisa/util/__init__.py @@ -1,3 +1,4 @@ -from .config import jax_enable_x64, set_cpu_cores - -__all__ = ['jax_enable_x64', 'set_cpu_cores'] +from .config import ( + jax_enable_x64 as jax_enable_x64, + set_cpu_cores as set_cpu_cores, +) diff --git a/src/elisa/util/config.py b/src/elisa/util/config.py index 66fbe0ff..3e82a0f3 100644 --- a/src/elisa/util/config.py +++ b/src/elisa/util/config.py @@ -6,8 +6,6 @@ from numpyro import enable_x64, set_host_device_count -__all__ = ['jax_enable_x64', 'set_cpu_cores'] - def jax_enable_x64(use_x64: bool) -> None: """Changes the default float precision of arrays in JAX. diff --git a/src/elisa/util/integrate.py b/src/elisa/util/integrate.py index 4c4eaa91..8a84ecd8 100644 --- a/src/elisa/util/integrate.py +++ b/src/elisa/util/integrate.py @@ -41,12 +41,14 @@ def make_integral_factory( method : {'quadgk', 'quadcc', 'quadts', 'romberg', 'rombergts'}, optional Numerical integration method used to integrate over the parameter. Available options are: + * 'quadgk' : global adaptive quadrature with Gauss-Konrod rule * 'quadcc' : global adaptive quadrature with Clenshaw-Curtis rule * 'quadts' : global adaptive quadrature with trapz tanh-sinh rule * 'romberg' : Romberg integration * 'rombergts' : Romberg integration with tanh-sinh (a.k.a. double exponential) transformation + The default is 'quadgk'. kwargs : dict, optional Extra kwargs passed to integration methods. See [1]_ for details. @@ -59,8 +61,7 @@ def make_integral_factory( References ---------- - .. [1] `quadax docs `_ + .. [1] `quadax docs `__ """ if method not in _QUAD_FN: diff --git a/src/elisa/util/typing.py b/src/elisa/util/typing.py index 978afc9a..fdeae994 100644 --- a/src/elisa/util/typing.py +++ b/src/elisa/util/typing.py @@ -5,6 +5,33 @@ from jax import Array from jax.typing import ArrayLike +__all__ = [ + 'PyFloat', + 'JAXFloat', + 'Float', + 'PRNGKey', + 'NumPyArray', + 'JAXArray', + 'Array', + 'ArrayLike', + 'CompID', + 'CompName', + 'CompParamName', + 'ParamID', + 'ParamName', + 'NameValMapping', + 'CompIDParamValMapping', + 'CompIDStrMapping', + 'ParamIDStrMapping', + 'ParamIDValMapping', + 'CompEval', + 'ConvolveEval', + 'ModelEval', + 'ModelCompiledFn', + 'NameLaTeX', + 'AdditiveFn', +] + T = TypeVar('T') PyFloat = Union[float, np.inexact] # must include 0-d NDArray with float dtype