From be7efbec1a9488792bec7acd59b048ffa4576a50 Mon Sep 17 00:00:00 2001 From: sdupourque Date: Thu, 19 Sep 2024 14:22:01 +0200 Subject: [PATCH 1/7] prior predictive coverage --- src/jaxspec/analysis/_plot.py | 35 ++++++++++++++++ src/jaxspec/fit.py | 78 +++++++++++++++++++++++++++++++++++ tests/test_mcmc.py | 7 +++- 3 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 src/jaxspec/analysis/_plot.py diff --git a/src/jaxspec/analysis/_plot.py b/src/jaxspec/analysis/_plot.py new file mode 100644 index 0000000..6b222fd --- /dev/null +++ b/src/jaxspec/analysis/_plot.py @@ -0,0 +1,35 @@ +import matplotlib.pyplot as plt +import numpy as np + +from jax.typing import ArrayLike +from scipy.stats import nbinom + + +def _plot_poisson_data_with_error( + ax: plt.Axes, + x_bins: ArrayLike, + y: ArrayLike, + percentiles: tuple = (16, 84), +): + """ + Plot Poisson data with error bars. We extrapolate the intrinsic error of the observation assuming a prior rate + distributed according to a Gamma RV. 
+ """ + y_low = nbinom.ppf(percentiles[0] / 100, y, 0.5) + y_high = nbinom.ppf(percentiles[1] / 100, y, 0.5) + + ax_to_plot = ax.errorbar( + np.sqrt(x_bins[0] * x_bins[1]), + y, + xerr=np.abs(x_bins - np.sqrt(x_bins[0] * x_bins[1])), + yerr=[ + y - y_low, + y_high - y, + ], + color="black", + linestyle="none", + alpha=0.3, + capsize=2, + ) + + return ax_to_plot diff --git a/src/jaxspec/fit.py b/src/jaxspec/fit.py index d0ebb12..358f8b2 100644 --- a/src/jaxspec/fit.py +++ b/src/jaxspec/fit.py @@ -10,6 +10,8 @@ import haiku as hk import jax import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np import numpyro from jax import random @@ -24,6 +26,7 @@ from numpyro.infer.reparam import TransformReparam from numpyro.infer.util import log_density +from .analysis._plot import _plot_poisson_data_with_error from .analysis.results import FitResult from .data import ObsConfiguration from .model.abc import SpectralModel @@ -287,6 +290,16 @@ def parameter_names(self) -> list[str]: observed_sites = relations["observed"] return [site for site in all_sites if site not in observed_sites] + @cached_property + def observation_names(self) -> list[str]: + """ + List of the observations. + """ + relations = get_model_relations(self.numpyro_model) + all_sites = relations["sample_sample"].keys() + observed_sites = relations["observed"] + return [site for site in all_sites if site in observed_sites] + def array_to_dict(self, theta): """ Convert an array of parameters to a dictionary of parameters. @@ -323,6 +336,71 @@ def get_initial_params(self, key: PRNGKey = PRNGKey(0), num_samples: int = 1): self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples )(key, observed=False) + def prior_predictive_coverage( + self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84) + ): + """ + Check if the prior distribution include the observed data. 
+ """ + key_prior, key_posterior = jax.random.split(key, 2) + prior_params = self.get_initial_params(key=key_prior, num_samples=num_samples) + posterior_observations = Predictive( + self.numpyro_model, + return_sites=self.observation_names, + num_samples=num_samples, + posterior_samples=prior_params, + )(key, observed=False) + + for key, value in self.observation_container.items(): + fig, axs = plt.subplots( + nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1] + ) + + _plot_poisson_data_with_error( + axs[0], + value.out_energies, + value.folded_counts.values, + percentiles=percentiles, + ) + + axs[0].stairs( + np.max(posterior_observations["obs_" + key], axis=0), + edges=[*list(value.out_energies[0]), value.out_energies[1][-1]], + baseline=np.min(posterior_observations["obs_" + key], axis=0), + alpha=0.3, + fill=True, + color=(0.15, 0.25, 0.45), + ) + + # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100 + counts = posterior_observations["obs_" + key] + observed = value.folded_counts.values + + num_samples = counts.shape[0] + + less_than_obs = (counts < observed).sum(axis=0) + equal_to_obs = (counts == observed).sum(axis=0) + + rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100 + + axs[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]]) + + axs[1].plot( + (value.out_energies.min(), value.out_energies.max()), + (50, 50), + color="black", + linestyle="--", + ) + + axs[1].set_xlabel("Energy (keV)") + axs[0].set_ylabel("Counts") + axs[1].set_ylabel("Rank (%)") + axs[1].set_ylim(0, 100) + axs[0].set_xlim(value.out_energies.min(), value.out_energies.max()) + axs[0].loglog() + plt.suptitle(f"Prior Predictive coverage for {key}") + plt.show() + class BayesianModelFitter(BayesianModel, ABC): def build_inference_data( diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index 2a5bd3b..baa53f7 100644 --- a/tests/test_mcmc.py +++ 
b/tests/test_mcmc.py @@ -1,4 +1,4 @@ -from jaxspec.fit import NSFitter +from jaxspec.fit import BayesianModel, NSFitter def test_convergence(get_individual_mcmc_results, get_joint_mcmc_result): @@ -12,3 +12,8 @@ def test_ns(obs_model_prior): obsconf = obsconfs[0] fitter = NSFitter(model, prior, obsconf) fitter.fit(num_samples=5000, num_live_points=200) + + +def test_prior_predictive_coverage(obs_model_prior): + obsconfs, model, prior = obs_model_prior + BayesianModel(model, prior, obsconfs).prior_predictive_coverage() From c1001b43c26bffafa7d0250516cf8817c62f5dbd Mon Sep 17 00:00:00 2001 From: sdupourque Date: Tue, 24 Sep 2024 11:05:38 +0200 Subject: [PATCH 2/7] saving results --- docs/faq/cookbook.md | 27 +++++++++++++++++++++++++-- src/jaxspec/fit.py | 34 ++++++++++++++++++++++------------ 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/docs/faq/cookbook.md b/docs/faq/cookbook.md index cae7b5f..05569d4 100644 --- a/docs/faq/cookbook.md +++ b/docs/faq/cookbook.md @@ -45,7 +45,7 @@ result = fitter.fit(num_samples=1_000) You should look at [`SpectralModel.photon_flux`][jaxspec.model.abc.SpectralModel.photon_flux] and [`SpectralModel.energy_flux`][jaxspec.model.abc.SpectralModel.energy_flux] methods. -``` python +```python import jax.numpy as jnp import matplotlib.pyplot as plt from jaxspec.model.additive import Blackbodyrad @@ -73,4 +73,27 @@ energy_flux = spectral_model.energy_flux(params, energies[:-1], energies[1:], n_ You should look at [`FitResult.photon_flux`][jaxspec.analysis.results.FitResult.photon_flux], [`FitResult.energy_flux`][jaxspec.analysis.results.FitResult.energy_flux], and -[`FitResult.luminosity`][jaxspec.analysis.results.FitResult.luminosity] \ No newline at end of file +[`FitResult.luminosity`][jaxspec.analysis.results.FitResult.luminosity] + +## Save and load inference results + +You can use the `dill` package to serialise and deserialise such objects. 
First, install it using `pip`: + +``` +pip install dill +``` + +Then use the following lines to save and load the files: + +```python +import dill + +# Save the results + +with open(r"result.pickle", "wb") as output_file: + dill.dump(result, output_file) + +# Load the results +with open(r"result.pickle", "rb") as input_file: + result_pickled = dill.load(input_file) +``` \ No newline at end of file diff --git a/src/jaxspec/fit.py b/src/jaxspec/fit.py index 358f8b2..df88794 100644 --- a/src/jaxspec/fit.py +++ b/src/jaxspec/fit.py @@ -171,7 +171,7 @@ def prior_distributions_func(): prior_distributions_func = prior_distributions self.prior_distributions_func = prior_distributions_func - self.init_params = self.get_initial_params() + self.init_params = self.prior_samples() @cached_property def observation_container(self) -> dict[str, ObsConfiguration]: @@ -323,7 +323,7 @@ def dict_to_array(self, dict_of_params): return theta - def get_initial_params(self, key: PRNGKey = PRNGKey(0), num_samples: int = 1): + def prior_samples(self, key: PRNGKey = PRNGKey(0), num_samples: int = 100): """ Get initial parameters for the model by sampling from the prior distribution @@ -332,9 +332,24 @@ def get_initial_params(self, key: PRNGKey = PRNGKey(0), num_samples: int = 1): num_samples: the number of samples to draw from the prior. 
""" - return Predictive( - self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples - )(key, observed=False) + @jax.jit + def prior_sample(key): + return Predictive( + self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples + )(key, observed=False) + + return prior_sample(key) + + def mock_observations(self, parameters, key: PRNGKey = PRNGKey(0)): + @jax.jit + def fakeit(key, parameters): + return Predictive( + self.numpyro_model, + return_sites=self.observation_names, + posterior_samples=parameters, + )(key, observed=False) + + return fakeit(key, parameters) def prior_predictive_coverage( self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84) @@ -343,13 +358,8 @@ def prior_predictive_coverage( Check if the prior distribution include the observed data. """ key_prior, key_posterior = jax.random.split(key, 2) - prior_params = self.get_initial_params(key=key_prior, num_samples=num_samples) - posterior_observations = Predictive( - self.numpyro_model, - return_sites=self.observation_names, - num_samples=num_samples, - posterior_samples=prior_params, - )(key, observed=False) + prior_params = self.prior_samples(key=key_prior, num_samples=num_samples) + posterior_observations = self.mock_observations(prior_params, key=key_posterior) for key, value in self.observation_container.items(): fig, axs = plt.subplots( From b173c32590b21eb296c33a7c2bff83d018aaebea Mon Sep 17 00:00:00 2001 From: sdupourque Date: Wed, 25 Sep 2024 14:23:35 +0200 Subject: [PATCH 3/7] update dependancies --- pyproject.toml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ef09fe1..0c0e04d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,13 +11,13 @@ documentation = "https://jaxspec.readthedocs.io/en/latest/" [tool.poetry.dependencies] python = ">=3.10,<3.12" -jax = "^0.4.30" +jax = "^0.4.33" jaxlib = "^0.4.30" numpy = "<2.0.0" pandas = "^2.2.0" astropy = 
"^6.0.0" -numpyro = "^0.15.2" -dm-haiku = ">=0.0.11,<0.0.13" +numpyro = "^0.15.3" +dm-haiku = "^0.0.12" networkx = "^3.1" matplotlib = "^3.8.0" arviz = ">=0.17.1,<0.20.0" @@ -28,13 +28,12 @@ gpjax = "^0.8.0" jaxopt = "^0.8.1" tinygp = "^0.3.0" seaborn = "^0.13.1" -mkdocstrings = ">=0.24,<0.27" sparse = "^0.15.1" optimistix = "^0.0.7" scipy = "<1.15" mendeleev = ">=0.15,<0.18" pyzmq = "<27" -jaxns = "^2.5.1" +jaxns = "^2.6.1" pooch = "^1.8.2" interpax = "^0.3.3" watermark = "^2.4.3" From f7d9dc99d7fc11ce52d42696d0df1818d08da366 Mon Sep 17 00:00:00 2001 From: sdupourque Date: Wed, 25 Sep 2024 14:42:27 +0200 Subject: [PATCH 4/7] fixing jaxns issues --- docs/faq/cookbook.md | 1 - pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/faq/cookbook.md b/docs/faq/cookbook.md index 05569d4..2f8569f 100644 --- a/docs/faq/cookbook.md +++ b/docs/faq/cookbook.md @@ -89,7 +89,6 @@ Then use the following lines to save and load the files: import dill # Save the results - with open(r"result.pickle", "wb") as output_file: dill.dump(result, output_file) diff --git a/pyproject.toml b/pyproject.toml index 0c0e04d..45da04c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ optimistix = "^0.0.7" scipy = "<1.15" mendeleev = ">=0.15,<0.18" pyzmq = "<27" -jaxns = "^2.6.1" +jaxns = "<2.6" pooch = "^1.8.2" interpax = "^0.3.3" watermark = "^2.4.3" From 1b41222da137b2a4ef9f3a7fc6d159d53306cd74 Mon Sep 17 00:00:00 2001 From: sdupourque Date: Wed, 25 Sep 2024 15:28:02 +0200 Subject: [PATCH 5/7] pretty rendering for spectral model parameters --- src/jaxspec/model/abc.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/jaxspec/model/abc.py b/src/jaxspec/model/abc.py index 23a68cb..9f6fe12 100644 --- a/src/jaxspec/model/abc.py +++ b/src/jaxspec/model/abc.py @@ -7,9 +7,11 @@ import jax import jax.numpy as jnp import networkx as nx +import rich from haiku._src import base from jax.scipy.integrate import trapezoid +from 
rich.table import Table from simpleeval import simple_eval @@ -110,6 +112,30 @@ def func_to_transform(e_low, e_high, n_points=2): def params(self): return self.transformed_func_photon.init(None, jnp.ones(10), jnp.ones(10)) + def __rich_repr__(self): + table = Table(title=str(self)) + + table.add_column("Component", justify="right", style="bold", no_wrap=True) + table.add_column("Parameter") + + params = self.params + + for component in params.keys(): + once = True + + for parameters in params[component].keys(): + table.add_row(component if once else "", parameters) + once = False + + return table + + def __repr_html_(self): + return self.__rich_repr__() + + def __repr__(self): + rich.print(self.__rich_repr__()) + return "" + def photon_flux(self, params, e_low, e_high, n_points=2): r""" Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model. From dfbe94fb53353e998abb68c05319ac999625526c Mon Sep 17 00:00:00 2001 From: sdupourque Date: Wed, 25 Sep 2024 15:32:04 +0200 Subject: [PATCH 6/7] add test repr --- tests/test_repr.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tests/test_repr.py diff --git a/tests/test_repr.py b/tests/test_repr.py new file mode 100644 index 0000000..e8b7e22 --- /dev/null +++ b/tests/test_repr.py @@ -0,0 +1,3 @@ +def test_gp_bkg(obs_model_prior): + _, model, _ = obs_model_prior + print(repr(model)) From 924d6161594813576ce90d562bc2135ee66144c5 Mon Sep 17 00:00:00 2001 From: sdupourque Date: Wed, 25 Sep 2024 15:44:30 +0200 Subject: [PATCH 7/7] fixing bad docs --- .readthedocs.yaml | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 3b44c0c..96778b1 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -24,7 +24,7 @@ build: - poetry install --with docs # Using insiders versions of mkdocs-material & mkdocstrings - pip uninstall mkdocs-material mkdocstrings mkdocstrings-python -y - - pip install 
git+https://$GH_TOKEN@github.com/squidfunk/mkdocs-material-insiders.git + - pip install git+https://$GH_TOKEN@github.com/squidfunk/mkdocs-material-insiders.git@9.5.36-insiders-4.53.13 - pip install git+https://$GH_TOKEN@github.com/pawamoy-insiders/mkdocstrings-python.git - pip install mkdocs-jupyter # This is bugged, I enforced it manually, let's see if it works diff --git a/pyproject.toml b/pyproject.toml index 45da04c..643acf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,10 +40,10 @@ watermark = "^2.4.3" [tool.poetry.group.docs.dependencies] -mkdocs = "^1.5.3" +mkdocs = "^1.6.1" mkdocs-material = "^9.4.6" mkdocstrings = {extras = ["python"], version = ">=0.24,<0.27"} -mkdocs-jupyter = ">=0.24.6,<0.26.0" +mkdocs-jupyter = "^0.25.0" [tool.poetry.group.test.dependencies]