Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/pip/simpleeval-gte-0.9.13-and-lt-…
Browse files Browse the repository at this point in the history
…1.1.0
  • Loading branch information
renecotyfanboy authored Oct 25, 2024
2 parents db2a730 + fa08469 commit 3e461b8
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 38 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/test-and-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ on:
- '.github/workflows/test-and-coverage.yml'
- 'docs/**'


permissions:
contents: read

jobs:
tests:

runs-on: ubuntu-latest

steps:
Expand All @@ -38,6 +36,7 @@ jobs:
- name: Test with pytest
run: docker run -t -v ./:/shared xspec-tests pytest --cov jaxspec --cov-report xml:/shared/coverage.xml
- name: "Upload coverage to Codecov"
if: ${{ contains(github.event.pull_request.changed_files, 'src/') || contains(github.event.push.changed_files, 'src/') }}
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}
Expand Down
30 changes: 23 additions & 7 deletions docs/examples/background.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ approach is equivalent to subtract the background to the observed spectrum when
from jaxspec.model.background import SubtractedBackground

fitter = MCMCFitter(model, prior, obs, background_model=SubtractedBackground())
result = fitter.fit(num_chains=4, num_warmup=1000, num_samples=1000, mcmc_kwargs={"progress_bar": True})
result_bkg_substracted = fitter.fit(num_chains=16, num_warmup=1000, num_samples=5000, mcmc_kwargs={"progress_bar": True})

result.plot_ppc()
result_bkg_substracted.plot_ppc()
```

![Subtracted background](statics/subtract_background.png)
Expand All @@ -29,9 +29,9 @@ it is to consider each background bin as a Poisson realisation of a counting pro
from jaxspec.model.background import BackgroundWithError

fitter = MCMCFitter(model, prior, obs, background_model=BackgroundWithError())
result = fitter.fit(num_chains=4, num_warmup=1000, num_samples=1000, mcmc_kwargs={"progress_bar": True})
result_bkg_with_spread = fitter.fit(num_chains=16, num_warmup=1000, num_samples=5000, mcmc_kwargs={"progress_bar": True})

result.plot_ppc()
result_bkg_with_spread.plot_ppc()
```

![Subtracted background with errors](statics/subtract_background_with_errors.png)
Expand All @@ -44,9 +44,25 @@ nodes will drive the flexibility of the Gaussian process, and it should always b
from jaxspec.model.background import GaussianProcessBackground

forward = MCMCFitter(model, prior, obs, background_model=GaussianProcessBackground(e_min=0.3, e_max=8, n_nodes=20))
result = forward.fit(num_chains=4, num_warmup=1000, num_samples=1000, mcmc_kwargs={"progress_bar": True})
result_bkg_gp = forward.fit(num_chains=16, num_warmup=1000, num_samples=5000, mcmc_kwargs={"progress_bar": True})

result.plot_ppc()
result_bkg_gp.plot_ppc()
```

![Subtracted background with errors](statics/background_gp.png)
![Subtracted background with errors](statics/background_gp.png)

We can compare the results for all these background models using the `plot_corner_comparison` function.

``` python
from jaxspec.analysis.compare import plot_corner_comparison

plot_corner_comparison(
{
"Background with no spread" : result_bkg_substracted,
"Background with spread" : result_bkg_with_spread,
"Gaussian process background" : result_bkg_gp,
}
)
```

![Background comparison](statics/background_comparison.png)
Binary file added docs/examples/statics/background_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/examples/statics/background_gp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/examples/statics/fitting_example_corner.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/examples/statics/fitting_example_ppc.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/examples/statics/subtract_background.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/examples/statics/subtract_background_with_errors.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxspec"
version = "0.1.3"
version = "0.1.4dev"
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
authors = ["sdupourque <sdupourque@irap.omp.eu>"]
license = "MIT"
Expand Down Expand Up @@ -29,9 +29,9 @@ jaxopt = "^0.8.1"
tinygp = "^0.3.0"
seaborn = "^0.13.1"
sparse = "^0.15.1"
optimistix = ">=0.0.7,<0.0.9"
optimistix = ">=0.0.7,<0.0.10"
scipy = "<1.15"
mendeleev = ">=0.15,<0.18"
mendeleev = ">=0.15,<0.19"
pyzmq = "<27"
jaxns = "<2.6"
pooch = "^1.8.2"
Expand All @@ -57,7 +57,7 @@ testbook = "^0.4.2"


[tool.poetry.group.dev.dependencies]
pre-commit = "^3.5.0"
pre-commit = ">=3.5,<5.0"
ruff = ">=0.2.1,<0.7.0"
jupyterlab = "^4.0.7"
notebook = "^7.0.6"
Expand Down
58 changes: 51 additions & 7 deletions src/jaxspec/data/observation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import xarray as xr

from .ogip import DataPHA


Expand All @@ -23,7 +24,16 @@ class Observation(xr.Dataset):
folded_background: xr.DataArray
"""The background counts, after grouping"""

__slots__ = ("grouping", "channel", "quality", "exposure", "background", "folded_background", "counts", "folded_counts")
__slots__ = (
"grouping",
"channel",
"quality",
"exposure",
"background",
"folded_background",
"counts",
"folded_counts",
)

_default_attributes = {"description": "X-ray observation dataset"}

Expand All @@ -46,7 +56,11 @@ def from_matrix(
background = np.zeros_like(counts, dtype=np.int64)

data_dict = {
"counts": (["instrument_channel"], np.asarray(counts, dtype=np.int64), {"description": "Counts", "unit": "photons"}),
"counts": (
["instrument_channel"],
np.asarray(counts, dtype=np.int64),
{"description": "Counts", "unit": "photons"},
),
"folded_counts": (
["folded_channel"],
np.asarray(np.ma.filled(grouping @ counts), dtype=np.int64),
Expand All @@ -57,7 +71,11 @@ def from_matrix(
grouping,
{"description": "Grouping matrix."},
),
"quality": (["instrument_channel"], np.asarray(quality, dtype=np.int64), {"description": "Quality flag."}),
"quality": (
["instrument_channel"],
np.asarray(quality, dtype=np.int64),
{"description": "Quality flag."},
),
"exposure": ([], float(exposure), {"description": "Total exposure", "unit": "s"}),
"backratio": (
["instrument_channel"],
Expand All @@ -84,20 +102,29 @@ def from_matrix(
return cls(
data_dict,
coords={
"channel": (["instrument_channel"], np.asarray(channel, dtype=np.int64), {"description": "Channel number"}),
"channel": (
["instrument_channel"],
np.asarray(channel, dtype=np.int64),
{"description": "Channel number"},
),
"grouped_channel": (
["folded_channel"],
np.arange(len(grouping @ counts), dtype=np.int64),
{"description": "Channel number"},
),
},
attrs=cls._default_attributes if attributes is None else attributes | cls._default_attributes,
attrs=cls._default_attributes
if attributes is None
else attributes | cls._default_attributes,
)

@classmethod
def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata):
if bkg is not None:
backratio = np.nan_to_num((pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal))
backratio = np.nan_to_num(
(pha.backscal * pha.exposure * pha.areascal)
/ (bkg.backscal * bkg.exposure * bkg.areascal)
)
else:
backratio = np.ones_like(pha.counts)

Expand All @@ -114,6 +141,14 @@ def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadat

@classmethod
def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata):
"""
Build an observation from a PHA file
Parameters:
pha_path : Path to the PHA file
bkg_path : Path to the background file
metadata : Additional metadata to add to the observation
"""
from .util import data_path_finder

arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path)
Expand Down Expand Up @@ -155,7 +190,16 @@ def plot_grouping(self):

fig = plt.figure(figsize=(6, 6))
gs = fig.add_gridspec(
2, 2, width_ratios=(4, 1), height_ratios=(1, 4), left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.05, hspace=0.05
2,
2,
width_ratios=(4, 1),
height_ratios=(1, 4),
left=0.1,
right=0.9,
bottom=0.1,
top=0.9,
wspace=0.05,
hspace=0.05,
)
ax = fig.add_subplot(gs[1, 0])
ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
Expand Down
40 changes: 29 additions & 11 deletions src/jaxspec/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class CountForwardModel(hk.Module):
A haiku module which allows to build the function that simulates the measured counts
"""

# TODO: It has no point of being a haiku module, it should be a simple function

def __init__(self, model: SpectralModel, folding: ObsConfiguration, sparse=False):
super().__init__()
self.model = model
Expand Down Expand Up @@ -352,7 +354,9 @@ def fakeit(key, parameters):
return fakeit(key, parameters)

def prior_predictive_coverage(
self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84)
self,
key: PRNGKey = PRNGKey(0),
num_samples: int = 1000,
):
"""
Check if the prior distribution include the observed data.
Expand All @@ -363,24 +367,36 @@ def prior_predictive_coverage(

for key, value in self.observation_container.items():
fig, axs = plt.subplots(
nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1]
nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
)

_plot_poisson_data_with_error(
axs[0],
value.out_energies,
value.folded_counts.values,
percentiles=percentiles,
percentiles=(16, 84),
)

axs[0].stairs(
np.max(posterior_observations["obs_" + key], axis=0),
edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
baseline=np.min(posterior_observations["obs_" + key], axis=0),
alpha=0.3,
fill=True,
color=(0.15, 0.25, 0.45),
)
for i, (envelop_percentiles, color, alpha) in enumerate(
zip(
[(16, 86), (2.5, 97.5), (0.15, 99.85)],
["#03045e", "#0077b6", "#00b4d8"],
[0.5, 0.4, 0.3],
)
):
lower, upper = np.percentile(
posterior_observations["obs_" + key], envelop_percentiles, axis=0
)

axs[0].stairs(
upper,
edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
baseline=lower,
alpha=alpha,
fill=True,
color=color,
label=rf"${1+i}\sigma$",
)

# rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
counts = posterior_observations["obs_" + key]
Expand Down Expand Up @@ -408,7 +424,9 @@ def prior_predictive_coverage(
axs[1].set_ylim(0, 100)
axs[0].set_xlim(value.out_energies.min(), value.out_energies.max())
axs[0].loglog()
axs[0].legend(loc="upper right")
plt.suptitle(f"Prior Predictive coverage for {key}")
plt.tight_layout()
plt.show()


Expand Down
2 changes: 1 addition & 1 deletion src/jaxspec/model/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
_, observed_counts = obs.out_energies, obs.folded_background.data
numpyro.deterministic(f"{name}", observed_counts)

return jnp.zeros_like(observed_counts)
return observed_counts


class BackgroundWithError(BackgroundModel):
Expand Down
12 changes: 6 additions & 6 deletions tests/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ def test_integrate_interval_gradient():

def hyp1f1_integral(a, b, z):
def integrand(x, a, b, z):
return jnp.exp(z * x) * x ** (a - 1) * (1 - x) ** (-a + b - 1)
return jnp.exp(z * x) * x ** (a - 1.) * (1. - x) ** (-a + b - 1.)

return integrate_interval(integrand)(0, 1, a, b, z) * gamma(b) / (gamma(a) * gamma(b - a))
return integrate_interval(integrand)(0., 1., a, b, z) * gamma(b) / (gamma(a) * gamma(b - a))

a = 1.5
b = 10.0
z = 0.5
a = jnp.asarray(1.5)
b = jnp.asarray(10.0)
z = jnp.asarray(0.5)

assert jnp.isclose(jax.grad(hyp1f1_integral)(a, b, z), jax.grad(hyp1f1)(a, b, z))

Expand All @@ -90,6 +90,6 @@ def test_integrate_positive_gradient():
"""

gamma_integral = integrate_positive(lambda t, z: t ** (z - 1) * jnp.exp(-t))
z = 2.5
z = jnp.asarray(2.5)

assert jnp.isclose(jax.grad(gamma_integral)(z), jax.grad(gamma)(z))

0 comments on commit 3e461b8

Please sign in to comment.