Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup data reading #139

Merged
merged 5 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxspec"
version = "0.0.3"
version = "0.0.4"
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
authors = ["sdupourque <sdupourque@irap.omp.eu>"]
license = "MIT"
Expand All @@ -17,7 +17,7 @@ numpyro = ">=0.13.2,<0.15.0"
dm-haiku = ">=0.0.11,<0.0.13"
networkx = "^3.1"
matplotlib = "^3.8.0"
arviz = "^0.17.0"
arviz = "^0.17.1"
chainconsumer = "^1.0.0"
simpleeval = "^0.9.13"
cmasher = "^1.6.3"
Expand All @@ -27,6 +27,7 @@ tinygp = "^0.3.0"
seaborn = "^0.13.1"
mkdocstrings = "^0.24.0"
sparse = "^0.15.1"
scipy = "<1.13"


[tool.poetry.group.docs.dependencies]
Expand Down
172 changes: 101 additions & 71 deletions src/jaxspec/data/obsconf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import numpy as np
import xarray as xr
import sparse
import scipy
from .instrument import Instrument
from .observation import Observation


def densify_xarray(xarray):
return xr.DataArray(xarray.data.todense(), dims=xarray.dims, coords=xarray.coords, attrs=xarray.attrs, name=xarray.name)


class ObsConfiguration(xr.Dataset):
"""
Class to store the data of a folding model, which is the link between the unfolded and folded spectra.
Expand Down Expand Up @@ -56,8 +53,8 @@

out_energies = np.stack(
(
np.asarray(self.coords["e_min_folded"].data.todense(), dtype=np.float64),
np.asarray(self.coords["e_max_folded"].data.todense(), dtype=np.float64),
np.asarray(self.coords["e_min_folded"].data, dtype=np.float64),
np.asarray(self.coords["e_max_folded"].data, dtype=np.float64),
)
)

Expand All @@ -67,92 +64,125 @@
def from_pha_file(
cls, pha_path, rmf_path=None, arf_path=None, bkg_path=None, low_energy: float = 1e-20, high_energy: float = 1e20
):
from .util import data_loader
from .util import data_path_finder

pha, arf, rmf, bkg, metadata = data_loader(pha_path, arf_path=arf_path, rmf_path=rmf_path, bkg_path=bkg_path)
arf_path_default, rmf_path_default, bkg_path_default = data_path_finder(pha_path)

instrument = Instrument.from_matrix(
rmf.sparse_matrix,
arf.specresp if arf is not None else np.ones_like(rmf.energ_lo),
rmf.energ_lo,
rmf.energ_hi,
rmf.e_min,
rmf.e_max,
)
arf_path = arf_path_default if arf_path is None else arf_path
rmf_path = rmf_path_default if rmf_path is None else rmf_path
bkg_path = bkg_path_default if bkg_path is None else bkg_path

if bkg is not None:
backratio = np.where(bkg.backscal > 0.0, pha.backscal / np.where(bkg.backscal > 0, bkg.backscal, 1.0), 0.0)
else:
backratio = np.ones_like(pha.counts)

observation = Observation.from_matrix(
pha.counts,
pha.grouping,
pha.channel,
pha.quality,
pha.exposure,
background=bkg.counts if bkg is not None else None,
backratio=backratio,
attributes=metadata,
)
instrument = Instrument.from_ogip_file(rmf_path, arf_path=arf_path)
observation = Observation.from_pha_file(pha_path, bkg_path=bkg_path)

return cls.from_instrument(instrument, observation, low_energy=low_energy, high_energy=high_energy)

@classmethod
def from_instrument(
cls, instrument: Instrument, observation: Observation, low_energy: float = 1e-20, high_energy: float = 1e20
):
# Exclude the bins flagged with bad quality
quality_filter = observation.quality == 0
grouping = observation.grouping * quality_filter
# First we unpack all the xarray data to classical np array for efficiency
# We also exclude the bins that are flagged with bad quality on the instrument
quality_filter = observation.quality.data == 0
grouping = scipy.sparse.csr_array(observation.grouping.data.to_scipy_sparse()) * quality_filter
e_min_channel = instrument.coords["e_min_channel"].data
e_max_channel = instrument.coords["e_max_channel"].data
e_min_unfolded = instrument.coords["e_min_unfolded"].data
e_max_unfolded = instrument.coords["e_max_unfolded"].data
redistribution = scipy.sparse.csr_array(instrument.redistribution.data.to_scipy_sparse())
area = instrument.area.data
exposure = observation.exposure.data

# Computing the lower and upper energies of the bins after grouping
# This is just a trick to compute it without 10 lines of code
e_min = (xr.where(grouping > 0, grouping, np.nan) * instrument.coords["e_min_channel"]).min(
skipna=True, dim="instrument_channel"
)
grouping_nan = observation.grouping.data * quality_filter
grouping_nan.fill_value = np.nan
e_min = sparse.nanmin(grouping_nan * e_min_channel, axis=1).todense()
e_max = sparse.nanmax(grouping_nan * e_max_channel, axis=1).todense()

e_max = (xr.where(grouping > 0, grouping, np.nan) * instrument.coords["e_max_channel"]).max(
skipna=True, dim="instrument_channel"
)

transfer_matrix = grouping @ (instrument.redistribution * instrument.area * observation.exposure)
transfer_matrix = transfer_matrix.assign_coords({"e_min_folded": e_min, "e_max_folded": e_max})
# Compute the transfer matrix
transfer_matrix = grouping @ (redistribution * area * exposure)

# Exclude bins out of the considered energy range, and bins without contribution from the RMF
row_idx = densify_xarray(((e_min > low_energy) & (e_max < high_energy)) * (grouping.sum(dim="instrument_channel") > 0))

col_idx = densify_xarray(
(instrument.coords["e_min_unfolded"] > 0) * (instrument.redistribution.sum(dim="instrument_channel") > 0)
)

# The transfer matrix is converted locally to csr format to allow FAST slicing
transfer_matrix_scipy = transfer_matrix.data.to_scipy_sparse().tocsr()
transfer_matrix_reduced = transfer_matrix_scipy[row_idx.data][:, col_idx.data]
transfer_matrix_reduced = sparse.COO.from_scipy_sparse(transfer_matrix_reduced)
row_idx = (e_min > low_energy) & (e_max < high_energy) & (grouping.sum(axis=1) > 0)
col_idx = (e_min_unfolded > 0) & (redistribution.sum(axis=0) > 0)

# A dummy zero matrix is put so that the slicing in xarray is fast
transfer_matrix.data = sparse.zeros_like(transfer_matrix.data)
transfer_matrix = transfer_matrix[row_idx][:, col_idx]

# The reduced transfer matrix is put back in the xarray
transfer_matrix.data = transfer_matrix_reduced

folded_counts = observation.folded_counts.copy().where(row_idx, drop=True)
# Apply this reduction to all the relevant arrays
transfer_matrix = sparse.COO.from_scipy_sparse(transfer_matrix[row_idx][:, col_idx])
folded_counts = observation.folded_counts.data[row_idx]
folded_backratio = observation.folded_backratio.data[row_idx]
area = instrument.area.data[col_idx]
e_min_folded = e_min[row_idx]
e_max_folded = e_max[row_idx]
e_min_unfolded = e_min_unfolded[col_idx]
e_max_unfolded = e_max_unfolded[col_idx]

if observation.folded_background is not None:
folded_background = observation.folded_background.copy().where(row_idx, drop=True)

folded_background = observation.folded_background.data[row_idx]
else:
folded_background = None
folded_background = np.zeros_like(folded_counts)

Check warning on line 124 in src/jaxspec/data/obsconf.py

View check run for this annotation

Codecov / codecov/patch

src/jaxspec/data/obsconf.py#L124

Added line #L124 was not covered by tests

data_dict = {
"transfer_matrix": (
["folded_channel", "unfolded_channel"],
transfer_matrix,
{
"description": "Transfer matrix to use to fold the incoming spectrum. It is built and restricted using the grouping, redistribution matrix, effective area, quality flags and energy bands defined by the user."
},
),
"area": (
["unfolded_channel"],
area,
{"description": "Effective area with the same restrictions as the transfer matrix.", "units": "cm^2"},
),
"exposure": ([], exposure, {"description": "Total exposure", "unit": "s"}),
"folded_counts": (
["folded_channel"],
folded_counts,
{
"description": "Folded counts after grouping, with the same restrictions as the transfer matrix.",
"unit": "photons",
},
),
"folded_backratio": (
["folded_channel"],
folded_backratio,
{"description": "Background scaling after grouping, with the same restrictions as the transfer matrix."},
),
"folded_background": (
["folded_channel"],
folded_background,
{
"description": "Folded background counts after grouping, with the same restrictions as the transfer matrix.",
"unit": "photons",
},
),
}

return cls(
{
"transfer_matrix": transfer_matrix,
"area": instrument.area.copy().where(col_idx, drop=True),
"exposure": observation.exposure,
"folded_backratio": observation.folded_backratio.copy().where(row_idx, drop=True),
"folded_counts": folded_counts,
"folded_background": folded_background,
}
data_dict,
coords={
"e_min_folded": (
["folded_channel"],
e_min_folded,
{"description": "Low energy of folded channel"},
),
"e_max_folded": (
["folded_channel"],
e_max_folded,
{"description": "High energy of folded channel"},
),
"e_min_unfolded": (
["unfolded_channel"],
e_min_unfolded,
{"description": "Low energy of unfolded channel"},
),
"e_max_unfolded": (
["unfolded_channel"],
e_max_unfolded,
{"description": "High energy of unfolded channel"},
),
},
attrs=observation.attrs | instrument.attrs,
)
30 changes: 24 additions & 6 deletions src/jaxspec/data/observation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import numpy as np
import xarray as xr
from .ogip import DataPHA


class Observation(xr.Dataset):
Expand Down Expand Up @@ -95,11 +95,7 @@
)

@classmethod
def from_pha_file(cls, pha_file: str | os.PathLike, **kwargs):
from .util import data_loader

pha, arf, rmf, bkg, metadata = data_loader(pha_file)

def from_ogip_container(cls, pha: DataPHA, bkg: DataPHA | None = None, **metadata):
if bkg is not None:
backratio = np.nan_to_num((pha.backscal * pha.exposure * pha.areascal) / (bkg.backscal * bkg.exposure * bkg.areascal))
else:
Expand All @@ -116,6 +112,28 @@
attributes=metadata,
)

@classmethod
def from_pha_file(cls, pha_path: str, bkg_path: str | None = None, **metadata):
    """
    Build an observation from a PHA file and, optionally, a background file.

    Parameters:
        pha_path: The PHA file path.
        bkg_path: The BKG file path. When None, the background path returned by
            `data_path_finder` for this PHA file is used, which may itself be None,
            in which case no background is loaded.
        **metadata: Extra entries forwarded as keyword arguments to `from_ogip_container`.
            The four file-path entries set below replace any same-named keys.
    """
    from .util import data_path_finder

    arf_path, rmf_path, bkg_path_default = data_path_finder(pha_path)
    bkg_path = bkg_path_default if bkg_path is None else bkg_path

    pha = DataPHA.from_file(pha_path)
    bkg = DataPHA.from_file(bkg_path) if bkg_path is not None else None

    # `**metadata` always binds a dict (empty when no extra keywords are given),
    # so it can be updated directly; a `metadata is None` guard would be unreachable.
    metadata.update(
        observation_file=pha_path,
        background_file=bkg_path,
        response_matrix_file=rmf_path,
        ancillary_response_file=arf_path,
    )

    return cls.from_ogip_container(pha, bkg=bkg, **metadata)

def plot_counts(self, **kwargs):
"""
Plot the counts
Expand Down
48 changes: 17 additions & 31 deletions src/jaxspec/data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from pathlib import Path
from numpy.typing import ArrayLike
from collections.abc import Mapping
from typing import TypeVar
from typing import TypeVar, Tuple
from astropy.io import fits

from .ogip import DataPHA, DataARF, DataRMF
from . import Observation, Instrument, ObsConfiguration
from ..model.abc import SpectralModel
from ..fit import CountForwardModel
Expand Down Expand Up @@ -201,46 +201,32 @@ def obs_model(p):
return fakeits[0] if len(fakeits) == 1 else fakeits


def data_loader(pha_path: str, arf_path=None, rmf_path=None, bkg_path=None):
def data_path_finder(pha_path: str) -> Tuple[str | None, str | None, str | None]:
"""
This function is a convenience function that allows to load PHA, ARF and RMF data
from a given PHA file, using either the ARF/RMF/BKG filenames in the header or the
specified filenames overwritten by the user.

This function tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
Parameters:
pha_path: The PHA file path.

Returns:
arf_path: The ARF file path.
rmf_path: The RMF file path.
bkg_path: The BKG file path.
"""

pha = DataPHA.from_file(pha_path)
directory = str(Path(pha_path).parent)

if arf_path is None:
if pha.ancrfile != "none" and pha.ancrfile != "":
arf_path = find_file_or_compressed_in_dir(pha.ancrfile, directory)

if rmf_path is None:
if pha.respfile != "none" and pha.respfile != "":
rmf_path = find_file_or_compressed_in_dir(pha.respfile, directory)

if bkg_path is None:
if pha.backfile.lower() != "none" and pha.backfile != "":
bkg_path = find_file_or_compressed_in_dir(pha.backfile, directory)
def find_path(file_name: str, directory: str) -> str | None:
if file_name.lower() != "none" and file_name != "":
return find_file_or_compressed_in_dir(file_name, directory)
else:
return None

arf = DataARF.from_file(arf_path) if arf_path is not None else None
rmf = DataRMF.from_file(rmf_path) if rmf_path is not None else None
bkg = DataPHA.from_file(bkg_path) if bkg_path is not None else None
header = fits.getheader(pha_path, "SPECTRUM")
directory = str(Path(pha_path).parent)

metadata = {
"observation_file": pha_path,
"background_file": bkg_path,
"response_matrix_file": rmf_path,
"ancillary_response_file": arf_path,
}
arf_path = find_path(header.get("ANCRFILE", "none"), directory)
rmf_path = find_path(header.get("RESPFILE", "none"), directory)
bkg_path = find_path(header.get("BACKFILE", "none"), directory)

return pha, arf, rmf, bkg, metadata
return arf_path, rmf_path, bkg_path


def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> str:
Expand Down
9 changes: 4 additions & 5 deletions src/jaxspec/model/additive.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,17 @@ class Diskbb(AdditiveComponent):
`Diskpbb` with $p=0.75$

??? abstract "Parameters"
* $T_{\text{in}}$ : Temperature at inner disk radius $\left[ \mathrm{keV}\right]$
* $\text{norm}$ : $\cos i(r_{\text{in}}/d)^{2}$,
where $r_{\text{in}}$ is "an apparent" inner disk radius $\left[\text{km}\right]$,
* $d$ the distance to the source in units of $10 \text{kpc}$,
* $i$ the angle of the disk ($i=0$ is face-on)
* $T_{\text{in}}$ : Temperature at inner disk radius $\left[ \mathrm{keV}\right]$
$d$ the distance to the source in units of $10 \text{kpc}$, $i$ the angle of the disk ($i=0$ is face-on)
"""

def continuum(self, energy):
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))
p = 0.75
tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
tout = 0.0
tin = hk.get_parameter("Tin", [], init=HaikuConstant(1))
norm = hk.get_parameter("norm", [], init=HaikuConstant(1))

# Tout is set to 0 as it is evaluated at R=infinity
def integrand(kT, e, tin, p):
Expand Down
Loading
Loading