Skip to content

Commit

Permalink
switching to nnx
Browse files Browse the repository at this point in the history
  • Loading branch information
renecotyfanboy committed Nov 4, 2024
1 parent 7265f4f commit 07bdd3f
Show file tree
Hide file tree
Showing 17 changed files with 833 additions and 1,005 deletions.
Binary file modified docs/examples/statics/fitting_example_corner.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/examples/statics/fitting_example_ppc.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxspec"
version = "0.1.4dev"
version = "0.2.0"
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
authors = ["sdupourque <sdupourque@irap.omp.eu>"]
license = "MIT"
Expand Down Expand Up @@ -28,7 +28,7 @@ gpjax = "^0.8.0"
jaxopt = "^0.8.1"
tinygp = "^0.3.0"
seaborn = "^0.13.1"
sparse = "^0.15.1"
sparse = "^0.15.4"
optimistix = ">=0.0.7,<0.0.10"
scipy = "<1.15"
mendeleev = ">=0.15,<0.19"
Expand Down
60 changes: 30 additions & 30 deletions src/jaxspec/_fit/_build_model.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import jax
import numpyro
from collections.abc import Callable
from typing import TYPE_CHECKING

import haiku as hk
import numpy as np
import jax
import jax.numpy as jnp
from typing import Callable
import numpy as np
import numpyro

from jax.experimental.sparse import BCOO
from typing import TYPE_CHECKING
from numpyro.distributions import Poisson
from jax.typing import ArrayLike
from numpyro.distributions import Distribution

from numpyro.distributions import Distribution, Poisson

if TYPE_CHECKING:
from ..model.abc import SpectralModel
from ..data import ObsConfiguration
from ..util.typing import PriorDictModel, PriorDictType

from ..model.abc import SpectralModel
from ..util.typing import PriorDictType


class CountForwardModel(hk.Module):
Expand All @@ -25,7 +24,7 @@ class CountForwardModel(hk.Module):

# TODO: It has no point of being a haiku module, it should be a simple function

def __init__(self, model: 'SpectralModel', folding: 'ObsConfiguration', sparse=False):
def __init__(self, model: "SpectralModel", folding: "ObsConfiguration", sparse=False):
super().__init__()
self.model = model
self.energies = jnp.asarray(folding.in_energies)
Expand All @@ -51,12 +50,11 @@ def __call__(self, parameters):


def forward_model(
model: 'SpectralModel',
parameters,
obs_configuration: 'ObsConfiguration',
sparse=False,
):

model: "SpectralModel",
parameters,
obs_configuration: "ObsConfiguration",
sparse=False,
):
energies = np.asarray(obs_configuration.in_energies)

if sparse:
Expand Down Expand Up @@ -86,7 +84,6 @@ def build_numpyro_model_for_single_obs(
"""

def numpyro_model(prior_params, observed=True):

# Return the expected countrate for a set of parameters
obs_model = jax.jit(lambda par: forward_model(model, par, obs, sparse=sparse))
countrate = obs_model(prior_params)
Expand All @@ -105,7 +102,6 @@ def numpyro_model(prior_params, observed=True):
else:
bkg_countrate = 0.0


# Register the observed value
# This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
Expand All @@ -118,23 +114,27 @@ def numpyro_model(prior_params, observed=True):
return numpyro_model


def build_prior(prior: 'PriorDictType', expand_shape: tuple = (), prefix=""):
def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
"""
Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
Must be used within a numpyro model.
"""
parameters = dict(hk.data_structures.to_haiku_dict(prior))

for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
if isinstance(sample, Distribution):
parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
parameters = {}

for key, value in prior.items():
# Split the key to extract the module name and parameter name
module_name, param_name = key.rsplit("_", 1)
if isinstance(value, Distribution):
parameters[key] = jnp.ones(expand_shape) * numpyro.sample(
f"{prefix}{module_name}_{param_name}", value
)

elif isinstance(sample, ArrayLike):
parameters[m][n] = jnp.ones(expand_shape) * sample
elif isinstance(value, ArrayLike):
parameters[key] = jnp.ones(expand_shape) * value

else:
raise ValueError(
f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
f"Invalid prior type {type(value)} for parameter {prefix}{module_name}_{param_name} : {value}"
)

return parameters
return parameters
21 changes: 11 additions & 10 deletions src/jaxspec/analysis/results.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, TypeVar

import arviz as az
Expand Down Expand Up @@ -102,15 +101,13 @@ def __init__(
self,
bayesian_fitter: BayesianModel,
inference_data: az.InferenceData,
structure: Mapping[K, V],
background_model: BackgroundModel = None,
):
self.model = bayesian_fitter.model
self.bayesian_fitter = bayesian_fitter
self.inference_data = inference_data
self.obsconfs = bayesian_fitter.observation_container
self.background_model = background_model
self._structure = structure

# Add the model used in fit to the metadata
for group in self.inference_data.groups():
Expand Down Expand Up @@ -173,27 +170,31 @@ def input_parameters(self) -> HaikuDict[ArrayLike]:
with seed(rng_seed=0):
input_parameters = self.bayesian_fitter.prior_distributions_func()

for module, parameter, value in traverse(input_parameters):
for key, value in input_parameters.items():
module, parameter = key.rsplit("_", 1)

if f"{module}_{parameter}" in posterior.keys():
# We add as extra dimension as there might be different values per observation
if posterior[f"{module}_{parameter}"].shape == samples_shape:
to_set = posterior[f"{module}_{parameter}"][..., None]
else:
to_set = posterior[f"{module}_{parameter}"]

input_parameters[module][parameter] = to_set
input_parameters[f"{module}_{parameter}"] = to_set

else:
# The parameter is fixed in this case, so we just broadcast is over chain and draws
input_parameters[module][parameter] = value[None, None, ...]
input_parameters[f"{module}_{parameter}"] = value[None, None, ...]

if len(total_shape) < len(input_parameters[module][parameter].shape):
if len(total_shape) < len(input_parameters[f"{module}_{parameter}"].shape):
# If there are only chains and draws, we reduce
input_parameters[module][parameter] = input_parameters[module][parameter][..., 0]
input_parameters[f"{module}_{parameter}"] = input_parameters[
f"{module}_{parameter}"
][..., 0]

else:
input_parameters[module][parameter] = jnp.broadcast_to(
input_parameters[module][parameter], total_shape
input_parameters[f"{module}_{parameter}"] = jnp.broadcast_to(
input_parameters[f"{module}_{parameter}"], total_shape
)

return input_parameters
Expand Down
80 changes: 5 additions & 75 deletions src/jaxspec/data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@
from pathlib import Path
from typing import Literal, TypeVar

import haiku as hk
import jax
import numpy as np
import numpyro

from astropy.io import fits
from numpy.typing import ArrayLike
from numpyro import handlers

from .._fit._build_model import CountForwardModel
from .._fit._build_model import forward_model
from ..model.abc import SpectralModel
from ..util.online_storage import table_manager
from . import Instrument, ObsConfiguration, Observation
Expand Down Expand Up @@ -127,67 +124,6 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
raise ValueError(f"{source} not recognized.")


def fakeit(
instrument: ObsConfiguration | list[ObsConfiguration],
model: SpectralModel,
parameters: Mapping[K, V],
rng_key: int = 0,
sparsify_matrix: bool = False,
) -> ArrayLike | list[ArrayLike]:
"""
Convenience function to simulate a spectrum from a given model and a set of parameters.
It requires an instrumental setup, and unlike in
[XSPEC's fakeit](https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/node72.html), the error on the counts is given
exclusively by Poisson statistics.
Parameters:
instrument: The instrumental setup.
model: The model to use.
parameters: The parameters of the model.
rng_key: The random number generator seed.
sparsify_matrix: Whether to sparsify the matrix or not.
"""

instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
fakeits = []

for i, instrument in enumerate(instruments):
transformed_model = hk.without_apply_rng(
hk.transform(
lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par)
)
)

def obs_model(p):
return transformed_model.apply(None, p)

with handlers.seed(rng_seed=rng_key):
counts = numpyro.sample(
f"likelihood_obs_{i}",
numpyro.distributions.Poisson(obs_model(parameters)),
)

"""
pha = DataPHA(
instrument.rmf.channel,
np.array(counts, dtype=int)*u.ct,
instrument.exposure,
grouping=instrument.grouping)
observation = Observation(
pha=pha,
arf=instrument.arf,
rmf=instrument.rmf,
low_energy=instrument.low_energy,
high_energy=instrument.high_energy
)
"""

fakeits.append(np.array(counts, dtype=int))

return fakeits[0] if len(fakeits) == 1 else fakeits


def fakeit_for_multiple_parameters(
instrument: ObsConfiguration | list[ObsConfiguration],
model: SpectralModel,
Expand All @@ -199,7 +135,6 @@ def fakeit_for_multiple_parameters(
"""
Convenience function to simulate multiple spectra from a given model and a set of parameters.
TODO : avoid redundancy, better doc and type hints
Parameters:
instrument: The instrumental setup.
Expand All @@ -214,24 +149,19 @@ def fakeit_for_multiple_parameters(
fakeits = []

for i, obs in enumerate(instruments):
transformed_model = hk.without_apply_rng(
hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparsify_matrix)(par))
countrate = jax.vmap(lambda p: forward_model(model, p, instrument, sparse=sparsify_matrix))(
parameters
)

@jax.jit
@jax.vmap
def obs_model(p):
return transformed_model.apply(None, p)

if apply_stat:
with handlers.seed(rng_seed=rng_key):
spectrum = numpyro.sample(
f"likelihood_obs_{i}",
numpyro.distributions.Poisson(obs_model(parameters)),
numpyro.distributions.Poisson(countrate),
)

else:
spectrum = obs_model(parameters)
spectrum = countrate

fakeits.append(spectrum)

Expand Down
10 changes: 5 additions & 5 deletions src/jaxspec/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .data import ObsConfiguration
from .model.abc import SpectralModel
from .model.background import BackgroundModel
from .util.typing import PriorDictModel, PriorDictType
from .util.typing import PriorDictType


class BayesianModel:
Expand Down Expand Up @@ -63,10 +63,12 @@ def __init__(

if not callable(prior_distributions):
# Validate the entry with pydantic
prior = PriorDictModel.from_dict(prior_distributions).nested_dict
# prior = PriorDictModel.from_dict(prior_distributions).

def prior_distributions_func():
return build_prior(prior, expand_shape=(len(self.observation_container),))
return build_prior(
prior_distributions, expand_shape=(len(self.observation_container),)
)

else:
prior_distributions_func = prior_distributions
Expand Down Expand Up @@ -544,7 +546,6 @@ def fit(
return FitResult(
self,
inference_data,
self.model.params,
background_model=self.background_model,
)

Expand Down Expand Up @@ -613,6 +614,5 @@ def fit(
return FitResult(
self,
inference_data,
self.model.params,
background_model=self.background_model,
)
Loading

0 comments on commit 07bdd3f

Please sign in to comment.