diff --git a/.github/workflows/test-and-coverage.yml b/.github/workflows/test-and-coverage.yml index 214df98..93809b8 100644 --- a/.github/workflows/test-and-coverage.yml +++ b/.github/workflows/test-and-coverage.yml @@ -16,13 +16,11 @@ on: - '.github/workflows/test-and-coverage.yml' - 'docs/**' - permissions: contents: read jobs: tests: - runs-on: ubuntu-latest steps: @@ -38,6 +36,7 @@ jobs: - name: Test with pytest run: docker run -t -v ./:/shared xspec-tests pytest --cov jaxspec --cov-report xml:/shared/coverage.xml - name: "Upload coverage to Codecov" + if: ${{ contains(github.event.pull_request.changed_files, 'src/') || contains(github.event.push.changed_files, 'src/') }} uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/docs/examples/background.md b/docs/examples/background.md index 19c095b..da93021 100644 --- a/docs/examples/background.md +++ b/docs/examples/background.md @@ -13,9 +13,9 @@ approach is equivalent to subtract the background to the observed spectrum when from jaxspec.model.background import SubtractedBackground fitter = MCMCFitter(model, prior, obs, background_model=SubtractedBackground()) -result = fitter.fit(num_chains=4, num_warmup=1000, num_samples=1000, mcmc_kwargs={"progress_bar": True}) +result_bkg_substracted = fitter.fit(num_chains=16, num_warmup=1000, num_samples=5000, mcmc_kwargs={"progress_bar": True}) -result.plot_ppc() +result_bkg_substracted.plot_ppc() ``` ![Subtracted background](statics/subtract_background.png) @@ -29,9 +29,9 @@ it is to consider each background bin as a Poisson realisation of a counting pro from jaxspec.model.background import BackgroundWithError fitter = MCMCFitter(model, prior, obs, background_model=BackgroundWithError()) -result = fitter.fit(num_chains=4, num_warmup=1000, num_samples=1000, mcmc_kwargs={"progress_bar": True}) +result_bkg_with_spread = fitter.fit(num_chains=16, num_warmup=1000, num_samples=5000, mcmc_kwargs={"progress_bar": True}) -result.plot_ppc() +result_bkg_with_spread.plot_ppc() ``` ![Subtracted background with errors](statics/subtract_background_with_errors.png) @@ -44,9 +44,25 @@ nodes will drive the flexibility of the Gaussian process, and it should always b from jaxspec.model.background import GaussianProcessBackground forward = MCMCFitter(model, prior, obs, background_model=GaussianProcessBackground(e_min=0.3, e_max=8, n_nodes=20)) -result = forward.fit(num_chains=4, num_warmup=1000, num_samples=1000, mcmc_kwargs={"progress_bar": True}) +result_bkg_gp = forward.fit(num_chains=16, num_warmup=1000, num_samples=5000, mcmc_kwargs={"progress_bar": True}) -result.plot_ppc() +result_bkg_gp.plot_ppc() ``` -![Subtracted background with errors](statics/background_gp.png) \ No newline at end of file +![Subtracted background with errors](statics/background_gp.png) + +We can compare the results for all these background models using the `plot_corner_comparison` function. + +``` python +from jaxspec.analysis.compare import plot_corner_comparison + +plot_corner_comparison( + { + "Background with no spread" : result_bkg_substracted, + "Background with spread" : result_bkg_with_spread, + "Gaussian process background" : result_bkg_gp, + } +) +``` + +![Background comparison](statics/background_comparison.png) diff --git a/docs/examples/statics/background_comparison.png b/docs/examples/statics/background_comparison.png new file mode 100644 index 0000000..31bc302 Binary files /dev/null and b/docs/examples/statics/background_comparison.png differ diff --git a/docs/examples/statics/background_gp.png b/docs/examples/statics/background_gp.png index 447c93b..512f516 100644 Binary files a/docs/examples/statics/background_gp.png and b/docs/examples/statics/background_gp.png differ diff --git a/docs/examples/statics/fitting_example_corner.png b/docs/examples/statics/fitting_example_corner.png index 1a57476..8705f17 100644 Binary files a/docs/examples/statics/fitting_example_corner.png and b/docs/examples/statics/fitting_example_corner.png differ diff --git a/docs/examples/statics/fitting_example_ppc.png b/docs/examples/statics/fitting_example_ppc.png index 53c5e15..f400e1f 100644 Binary files a/docs/examples/statics/fitting_example_ppc.png and b/docs/examples/statics/fitting_example_ppc.png differ diff --git a/docs/examples/statics/subtract_background.png b/docs/examples/statics/subtract_background.png index 0064c3c..27dd5dd 100644 Binary files a/docs/examples/statics/subtract_background.png and b/docs/examples/statics/subtract_background.png differ diff --git a/docs/examples/statics/subtract_background_with_errors.png b/docs/examples/statics/subtract_background_with_errors.png index c9f297f..c673ea6 100644 Binary files a/docs/examples/statics/subtract_background_with_errors.png and b/docs/examples/statics/subtract_background_with_errors.png differ diff --git a/pyproject.toml b/pyproject.toml index d0901ba..68caf95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "jaxspec" -version = "0.1.3" +version = "0.1.4dev" description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy." authors = ["sdupourque "] license = "MIT" @@ -29,9 +29,9 @@ jaxopt = "^0.8.1" tinygp = "^0.3.0" seaborn = "^0.13.1" sparse = "^0.15.1" -optimistix = ">=0.0.7,<0.0.9" +optimistix = ">=0.0.7,<0.0.10" scipy = "<1.15" -mendeleev = ">=0.15,<0.18" +mendeleev = ">=0.15,<0.19" pyzmq = "<27" jaxns = "<2.6" pooch = "^1.8.2" @@ -57,7 +57,7 @@ testbook = "^0.4.2" [tool.poetry.group.dev.dependencies] -pre-commit = "^3.5.0" +pre-commit = ">=3.5,<5.0" ruff = ">=0.2.1,<0.7.0" jupyterlab = "^4.0.7" notebook = "^7.0.6" diff --git a/src/jaxspec/data/observation.py b/src/jaxspec/data/observation.py index 0be6cb1..4240527 100644 --- a/src/jaxspec/data/observation.py +++ b/src/jaxspec/data/observation.py @@ -1,5 +1,6 @@ import numpy as np import xarray as xr + from .ogip import DataPHA @@ -23,7 +24,16 @@ class Observation(xr.Dataset): folded_background: xr.DataArray """The background counts, after grouping""" - __slots__ = ("grouping", "channel", "quality", "exposure", "background", "folded_background", "counts", "folded_counts") + __slots__ = ( + "grouping", + "channel", + "quality", + "exposure", + "background", + "folded_background", + "counts", + "folded_counts", + ) _default_attributes = {"description": "X-ray observation dataset"} @@ -46,7 +56,11 @@ def from_matrix( background = np.zeros_like(counts, dtype=np.int64) data_dict = { - "counts": (["instrument_channel"], np.asarray(counts, dtype=np.int64), {"description": "Counts", "unit": "photons"}), + "counts": ( + ["instrument_channel"], + np.asarray(counts, dtype=np.int64), + {"description": "Counts", "unit": "photons"}, + ), "folded_counts": ( ["folded_channel"], np.asarray(np.ma.filled(grouping @ counts), dtype=np.int64), @@ -57,7 +71,11 @@ def from_matrix( grouping, {"description": "Grouping matrix."}, ), - "quality": (["instrument_channel"], np.asarray(quality, dtype=np.int64), {"description": "Quality flag."}), + "quality": ( + ["instrument_channel"], + np.asarray(quality, dtype=np.int64), + {"description": "Quality flag."}, + ), "exposure": ([], float(exposure), {"description": "Total exposure", "unit": "s"}), "backratio": ( ["instrument_channel"], @@ -84,20 +102,29 @@ def from_matrix( return cls( data_dict, coords={ - "channel": (["instrument_channel"], np.asarray(channel, dtype=np.int64), {"description": "Channel number"}), + "channel": ( + ["instrument_channel"], + np.asarray(channel, dtype=np.int64), + {"description": "Channel number"}, + ), "grouped_channel": ( ["folded_channel"], np.arange(len(grouping @ counts), dtype=np.int64), {"description": "Channel number"}, ), }, - attrs=cls._default_attributes if attributes is None else attributes | cls._default_attributes, + attrs=cls._default_attributes + if attributes is None + else attributes | cls._default_attributes, ) @classmethod def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata): if bkg is not None: - backratio = np.nan_to_num((pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal)) + backratio = np.nan_to_num( + (pha.backscal * pha.exposure * pha.areascal) + / (bkg.backscal * bkg.exposure * bkg.areascal) + ) else: backratio = np.ones_like(pha.counts) @@ -114,6 +141,14 @@ def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadat @classmethod def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata): + """ + Build an observation from a PHA file + + Parameters: + pha_path : Path to the PHA file + bkg_path : Path to the background file + metadata : Additional metadata to add to the observation + """ from .util import data_path_finder arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path) @@ -155,7 +190,16 @@ def plot_grouping(self): fig = plt.figure(figsize=(6, 6)) gs = fig.add_gridspec( - 2, 2, width_ratios=(4, 1), height_ratios=(1, 4), left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.05, hspace=0.05 + 2, + 2, + width_ratios=(4, 1), + height_ratios=(1, 4), + left=0.1, + right=0.9, + bottom=0.1, + top=0.9, + wspace=0.05, + hspace=0.05, ) ax = fig.add_subplot(gs[1, 0]) ax_histx = fig.add_subplot(gs[0, 0], sharex=ax) diff --git a/src/jaxspec/fit.py b/src/jaxspec/fit.py index df88794..0371390 100644 --- a/src/jaxspec/fit.py +++ b/src/jaxspec/fit.py @@ -104,6 +104,8 @@ class CountForwardModel(hk.Module): A haiku module which allows to build the function that simulates the measured counts """ + # TODO: It has no point of being a haiku module, it should be a simple function + def __init__(self, model: SpectralModel, folding: ObsConfiguration, sparse=False): super().__init__() self.model = model @@ -352,7 +354,9 @@ def fakeit(key, parameters): return fakeit(key, parameters) def prior_predictive_coverage( - self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84) + self, + key: PRNGKey = PRNGKey(0), + num_samples: int = 1000, ): """ Check if the prior distribution include the observed data. @@ -363,24 +367,36 @@ def prior_predictive_coverage( for key, value in self.observation_container.items(): fig, axs = plt.subplots( - nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1] + nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1] ) _plot_poisson_data_with_error( axs[0], value.out_energies, value.folded_counts.values, - percentiles=percentiles, + percentiles=(16, 84), ) - axs[0].stairs( - np.max(posterior_observations["obs_" + key], axis=0), - edges=[*list(value.out_energies[0]), value.out_energies[1][-1]], - baseline=np.min(posterior_observations["obs_" + key], axis=0), - alpha=0.3, - fill=True, - color=(0.15, 0.25, 0.45), - ) + for i, (envelop_percentiles, color, alpha) in enumerate( + zip( + [(16, 86), (2.5, 97.5), (0.15, 99.85)], + ["#03045e", "#0077b6", "#00b4d8"], + [0.5, 0.4, 0.3], + ) + ): + lower, upper = np.percentile( + posterior_observations["obs_" + key], envelop_percentiles, axis=0 + ) + + axs[0].stairs( + upper, + edges=[*list(value.out_energies[0]), value.out_energies[1][-1]], + baseline=lower, + alpha=alpha, + fill=True, + color=color, + label=rf"${1+i}\sigma$", + ) # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100 counts = posterior_observations["obs_" + key] @@ -408,7 +424,9 @@ def prior_predictive_coverage( axs[1].set_ylim(0, 100) axs[0].set_xlim(value.out_energies.min(), value.out_energies.max()) axs[0].loglog() + axs[0].legend(loc="upper right") plt.suptitle(f"Prior Predictive coverage for {key}") + plt.tight_layout() plt.show() diff --git a/src/jaxspec/model/background.py b/src/jaxspec/model/background.py index bac15f8..91895e5 100644 --- a/src/jaxspec/model/background.py +++ b/src/jaxspec/model/background.py @@ -39,7 +39,7 @@ def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True): _, observed_counts = obs.out_energies, obs.folded_background.data numpyro.deterministic(f"{name}", observed_counts) - return jnp.zeros_like(observed_counts) + return observed_counts class BackgroundWithError(BackgroundModel): diff --git a/tests/test_integrate.py b/tests/test_integrate.py index c2a5d97..5fb7ff1 100644 --- a/tests/test_integrate.py +++ b/tests/test_integrate.py @@ -73,13 +73,13 @@ def test_integrate_interval_gradient(): def hyp1f1_integral(a, b, z): def integrand(x, a, b, z): - return jnp.exp(z * x) * x ** (a - 1) * (1 - x) ** (-a + b - 1) + return jnp.exp(z * x) * x ** (a - 1.) * (1. - x) ** (-a + b - 1.) - return integrate_interval(integrand)(0, 1, a, b, z) * gamma(b) / (gamma(a) * gamma(b - a)) + return integrate_interval(integrand)(0., 1., a, b, z) * gamma(b) / (gamma(a) * gamma(b - a)) - a = 1.5 - b = 10.0 - z = 0.5 + a = jnp.asarray(1.5) + b = jnp.asarray(10.0) + z = jnp.asarray(0.5) assert jnp.isclose(jax.grad(hyp1f1_integral)(a, b, z), jax.grad(hyp1f1)(a, b, z)) @@ -90,6 +90,6 @@ def test_integrate_positive_gradient(): """ gamma_integral = integrate_positive(lambda t, z: t ** (z - 1) * jnp.exp(-t)) - z = 2.5 + z = jnp.asarray(2.5) assert jnp.isclose(jax.grad(gamma_integral)(z), jax.grad(gamma)(z))