Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prior features #188

Merged
merged 8 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ build:
- poetry install --with docs
# Using insiders versions of mkdocs-material & mkdocstrings
- pip uninstall mkdocs-material mkdocstrings mkdocstrings-python -y
- pip install git+https://$GH_TOKEN@github.com/squidfunk/mkdocs-material-insiders.git
- pip install git+https://$GH_TOKEN@github.com/squidfunk/mkdocs-material-insiders.git@9.5.36-insiders-4.53.13
- pip install git+https://$GH_TOKEN@github.com/pawamoy-insiders/mkdocstrings-python.git
- pip install mkdocs-jupyter # This is bugged, I enforced it manually, let's see if it works

Expand Down
26 changes: 24 additions & 2 deletions docs/faq/cookbook.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ result = fitter.fit(num_samples=1_000)
You should look at [`SpectralModel.photon_flux`][jaxspec.model.abc.SpectralModel.photon_flux] and
[`SpectralModel.energy_flux`][jaxspec.model.abc.SpectralModel.energy_flux] methods.

``` python
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxspec.model.additive import Blackbodyrad
Expand Down Expand Up @@ -73,4 +73,26 @@ energy_flux = spectral_model.energy_flux(params, energies[:-1], energies[1:], n_

You should look at [`FitResult.photon_flux`][jaxspec.analysis.results.FitResult.photon_flux],
[`FitResult.energy_flux`][jaxspec.analysis.results.FitResult.energy_flux], and
[`FitResult.luminosity`][jaxspec.analysis.results.FitResult.luminosity]
[`FitResult.luminosity`][jaxspec.analysis.results.FitResult.luminosity]

## Save and load inference results

You can use the `dill` package to serialise and un-serialise such objects. First you should install it using `pip`

```
pip install dill
```

Then use the following lines to save and load the files:

```python
import dill

# Save the results
with open(r"result.pickle", "wb") as output_file:
dill.dump(result, output_file)

# Load the results
with open(r"result.pickle", "rb") as input_file:
result_pickled = dill.load(input_file)
```
13 changes: 6 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ documentation = "https://jaxspec.readthedocs.io/en/latest/"

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
jax = "^0.4.30"
jax = "^0.4.33"
jaxlib = "^0.4.30"
numpy = "<2.0.0"
pandas = "^2.2.0"
astropy = "^6.0.0"
numpyro = "^0.15.2"
dm-haiku = ">=0.0.11,<0.0.13"
numpyro = "^0.15.3"
dm-haiku = "^0.0.12"
networkx = "^3.1"
matplotlib = "^3.8.0"
arviz = ">=0.17.1,<0.20.0"
Expand All @@ -28,23 +28,22 @@ gpjax = "^0.8.0"
jaxopt = "^0.8.1"
tinygp = "^0.3.0"
seaborn = "^0.13.1"
mkdocstrings = ">=0.24,<0.27"
sparse = "^0.15.1"
optimistix = "^0.0.7"
scipy = "<1.15"
mendeleev = ">=0.15,<0.18"
pyzmq = "<27"
jaxns = "^2.5.1"
jaxns = "<2.6"
pooch = "^1.8.2"
interpax = "^0.3.3"
watermark = "^2.4.3"


[tool.poetry.group.docs.dependencies]
mkdocs = "^1.5.3"
mkdocs = "^1.6.1"
mkdocs-material = "^9.4.6"
mkdocstrings = {extras = ["python"], version = ">=0.24,<0.27"}
mkdocs-jupyter = ">=0.24.6,<0.26.0"
mkdocs-jupyter = "^0.25.0"


[tool.poetry.group.test.dependencies]
Expand Down
35 changes: 35 additions & 0 deletions src/jaxspec/analysis/_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import matplotlib.pyplot as plt
import numpy as np

from jax.typing import ArrayLike
from scipy.stats import nbinom


def _plot_poisson_data_with_error(
ax: plt.Axes,
x_bins: ArrayLike,
y: ArrayLike,
percentiles: tuple = (16, 84),
):
"""
Plot Poisson data with error bars. We extrapolate the intrinsic error of the observation assuming a prior rate
distributed according to a Gamma RV.
"""
y_low = nbinom.ppf(percentiles[0] / 100, y, 0.5)
y_high = nbinom.ppf(percentiles[1] / 100, y, 0.5)

ax_to_plot = ax.errorbar(
np.sqrt(x_bins[0] * x_bins[1]),
y,
xerr=np.abs(x_bins - np.sqrt(x_bins[0] * x_bins[1])),
yerr=[
y - y_low,
y_high - y,
],
color="black",
linestyle="none",
alpha=0.3,
capsize=2,
)

return ax_to_plot
98 changes: 93 additions & 5 deletions src/jaxspec/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro

from jax import random
Expand All @@ -24,6 +26,7 @@
from numpyro.infer.reparam import TransformReparam
from numpyro.infer.util import log_density

from .analysis._plot import _plot_poisson_data_with_error
from .analysis.results import FitResult
from .data import ObsConfiguration
from .model.abc import SpectralModel
Expand Down Expand Up @@ -168,7 +171,7 @@ def prior_distributions_func():
prior_distributions_func = prior_distributions

self.prior_distributions_func = prior_distributions_func
self.init_params = self.get_initial_params()
self.init_params = self.prior_samples()

@cached_property
def observation_container(self) -> dict[str, ObsConfiguration]:
Expand Down Expand Up @@ -287,6 +290,16 @@ def parameter_names(self) -> list[str]:
observed_sites = relations["observed"]
return [site for site in all_sites if site not in observed_sites]

@cached_property
def observation_names(self) -> list[str]:
"""
List of the observations.
"""
relations = get_model_relations(self.numpyro_model)
all_sites = relations["sample_sample"].keys()
observed_sites = relations["observed"]
return [site for site in all_sites if site in observed_sites]

def array_to_dict(self, theta):
"""
Convert an array of parameters to a dictionary of parameters.
Expand All @@ -310,7 +323,7 @@ def dict_to_array(self, dict_of_params):

return theta

def get_initial_params(self, key: PRNGKey = PRNGKey(0), num_samples: int = 1):
def prior_samples(self, key: PRNGKey = PRNGKey(0), num_samples: int = 100):
"""
Get initial parameters for the model by sampling from the prior distribution

Expand All @@ -319,9 +332,84 @@ def get_initial_params(self, key: PRNGKey = PRNGKey(0), num_samples: int = 1):
num_samples: the number of samples to draw from the prior.
"""

return Predictive(
self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
)(key, observed=False)
@jax.jit
def prior_sample(key):
return Predictive(
self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
)(key, observed=False)

return prior_sample(key)

def mock_observations(self, parameters, key: PRNGKey = PRNGKey(0)):
@jax.jit
def fakeit(key, parameters):
return Predictive(
self.numpyro_model,
return_sites=self.observation_names,
posterior_samples=parameters,
)(key, observed=False)

return fakeit(key, parameters)

def prior_predictive_coverage(
self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84)
):
"""
Check if the prior distribution include the observed data.
"""
key_prior, key_posterior = jax.random.split(key, 2)
prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
posterior_observations = self.mock_observations(prior_params, key=key_posterior)

for key, value in self.observation_container.items():
fig, axs = plt.subplots(
nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1]
)

_plot_poisson_data_with_error(
axs[0],
value.out_energies,
value.folded_counts.values,
percentiles=percentiles,
)

axs[0].stairs(
np.max(posterior_observations["obs_" + key], axis=0),
edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
baseline=np.min(posterior_observations["obs_" + key], axis=0),
alpha=0.3,
fill=True,
color=(0.15, 0.25, 0.45),
)

# rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
counts = posterior_observations["obs_" + key]
observed = value.folded_counts.values

num_samples = counts.shape[0]

less_than_obs = (counts < observed).sum(axis=0)
equal_to_obs = (counts == observed).sum(axis=0)

rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100

axs[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])

axs[1].plot(
(value.out_energies.min(), value.out_energies.max()),
(50, 50),
color="black",
linestyle="--",
)

axs[1].set_xlabel("Energy (keV)")
axs[0].set_ylabel("Counts")
axs[1].set_ylabel("Rank (%)")
axs[1].set_ylim(0, 100)
axs[0].set_xlim(value.out_energies.min(), value.out_energies.max())
axs[0].loglog()
plt.suptitle(f"Prior Predictive coverage for {key}")
plt.show()


class BayesianModelFitter(BayesianModel, ABC):
Expand Down
26 changes: 26 additions & 0 deletions src/jaxspec/model/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import jax
import jax.numpy as jnp
import networkx as nx
import rich

from haiku._src import base
from jax.scipy.integrate import trapezoid
from rich.table import Table
from simpleeval import simple_eval


Expand Down Expand Up @@ -110,6 +112,30 @@
def params(self):
return self.transformed_func_photon.init(None, jnp.ones(10), jnp.ones(10))

def __rich_repr__(self):
table = Table(title=str(self))

table.add_column("Component", justify="right", style="bold", no_wrap=True)
table.add_column("Parameter")

params = self.params

for component in params.keys():
once = True

for parameters in params[component].keys():
table.add_row(component if once else "", parameters)
once = False

return table

def __repr_html_(self):
return self.__rich_repr__()

Check warning on line 133 in src/jaxspec/model/abc.py

View check run for this annotation

Codecov / codecov/patch

src/jaxspec/model/abc.py#L133

Added line #L133 was not covered by tests

def __repr__(self):
rich.print(self.__rich_repr__())
return ""

def photon_flux(self, params, e_low, e_high, n_points=2):
r"""
Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
Expand Down
7 changes: 6 additions & 1 deletion tests/test_mcmc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jaxspec.fit import NSFitter
from jaxspec.fit import BayesianModel, NSFitter


def test_convergence(get_individual_mcmc_results, get_joint_mcmc_result):
Expand All @@ -12,3 +12,8 @@ def test_ns(obs_model_prior):
obsconf = obsconfs[0]
fitter = NSFitter(model, prior, obsconf)
fitter.fit(num_samples=5000, num_live_points=200)


def test_prior_predictive_coverage(obs_model_prior):
obsconfs, model, prior = obs_model_prior
BayesianModel(model, prior, obsconfs).prior_predictive_coverage()
3 changes: 3 additions & 0 deletions tests/test_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def test_gp_bkg(obs_model_prior):
_, model, _ = obs_model_prior
print(repr(model))
Loading