From c12c37074739c29b2a003d9abc8946808abdf18e Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Mon, 12 Jun 2023 18:04:07 +0530 Subject: [PATCH 01/50] Adding_gp_feature --- docs/changes/739.feature.rst | 1 + setup.cfg | 6 + stingray/modeling/gpmodeling.py | 426 ++++++++++++++++++++++++++++++++ 3 files changed, 433 insertions(+) create mode 100644 docs/changes/739.feature.rst create mode 100644 stingray/modeling/gpmodeling.py diff --git a/docs/changes/739.feature.rst b/docs/changes/739.feature.rst new file mode 100644 index 000000000..b1661b37c --- /dev/null +++ b/docs/changes/739.feature.rst @@ -0,0 +1 @@ +A feature dealing with Gaussian Processes for Qpo analysis \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 448f2937f..e5ac3a72f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,6 +22,12 @@ install_requires = scipy>=1.1.0 ; Matplotlib 3.4.0 is incompatible with Astropy matplotlib>=3.0,!=3.4.0 + jax + tinygp + jaxns + etils + tensorflow_probability + typing_extensions [options.entry_points] console_scripts = diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py new file mode 100644 index 000000000..1c2917a28 --- /dev/null +++ b/stingray/modeling/gpmodeling.py @@ -0,0 +1,426 @@ +import numpy as np +import matplotlib.pyplot as plt +import jax +import jax.numpy as jnp +import functools +import tensorflow_probability.substrates.jax as tfp + +from jax import jit, random + +from tinygp import GaussianProcess, kernels +from stingray import Lightcurve + +from jaxns import ExactNestedSampler +from jaxns import TerminationCondition + +# from jaxns import analytic_log_evidence +from jaxns import Prior, Model + +jax.config.update("jax_enable_x64", True) + +tfpd = tfp.distributions +tfpb = tfp.bijectors + +__all__ = ["GP", "GPResult"] + + +def get_kernel(kernel_type, kernel_params): + """ + Function for producing the kernel for the Gaussian Process. + Returns the selected Tinygp kernel + + Parameters + ---------- + kernel_type: string + The type of kernel to be used for the Gaussian Process + To be selected from the kernels already implemented + + kernel_params: dict + Dictionary containing the parameters for the kernel + Should contain the parameters for the selected kernel + + """ + if kernel_type == "QPO_plus_RN": + kernel = kernels.quasisep.Exp( + scale=1 / kernel_params["crn"], sigma=(kernel_params["arn"]) ** 0.5 + ) + kernels.quasisep.Celerite( + a=kernel_params["aqpo"], + b=0.0, + c=kernel_params["cqpo"], + d=2 * jnp.pi * kernel_params["freq"], + ) + return kernel + elif kernel_type == "RN": + kernel = kernels.quasisep.Exp( + scale=1 / kernel_params["crn"], sigma=(kernel_params["arn"]) ** 0.5 + ) + return kernel + + +def get_mean(mean_type, mean_params): + """ + Function for producing the mean for the Gaussian Process. + + Parameters + ---------- + mean_type: string + The type of mean to be used for the Gaussian Process + To be selected from the mean functions already implemented + + mean_params: dict + Dictionary containing the parameters for the mean + Should contain the parameters for the selected mean + + """ + if mean_type == "gaussian": + mean = functools.partial(_gaussian, mean_params=mean_params) + elif mean_type == "exponential": + mean = functools.partial(_exponential, mean_params=mean_params) + elif mean_type == "constant": + mean = functools.partial(_constant, mean_params=mean_params) + return mean + + +def _gaussian(t, mean_params): + return mean_params["A"] * jnp.exp( + -((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig"] ** 2)) + ) + + +def _exponential(t, mean_params): + return mean_params["A"] * jnp.exp(-jnp.abs((t - mean_params["t0"])) / mean_params["sig"]) + + +def _constant(t, mean_params): + return mean_params["A"] * jnp.ones_like(t) + + +class GP: + """ + Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian + Process on the lightcurve data, for the given kernel. + + Parameters + ---------- + lc: Stingray.Lightcurve object + The lightcurve on which the gaussian process, is to be fitted + + Model_type: string tuple + Has two strings with the first being the name of the kernel type + and the secound being the mean type + + Model_parameter: dict, default = None + Dictionary conatining the parameters for the mean and kernel + The keys should be accourding to the selected kernel and mean + coressponding to the Model_type + By default, it takes a value None, and the kernel and mean are + then bulit using the pre-set parameters. + + Other Parameters + ---------------- + kernel: class: `TinyGp.kernel` object + The tinygp kernel for the GP + + mean: class: `TinyGp.mean` object + The tinygp mean for the GP + + maingp: class: `TinyGp.GaussianProcess` object + The tinygp gaussian process made on the lightcurve + + """ + + def __init__(self, Lc: Lightcurve, Model_type: tuple, Model_params: dict = None) -> None: + self.lc = Lc + self.Model_type = Model_type + self.Model_param = Model_params + self.kernel = get_kernel(self.Model_type[0], self.Model_param) + self.mean = get_mean(self.Model_type[1], self.Model_param) + self.maingp = GaussianProcess( + self.kernel, Lc.time, mean=self.mean, diag=Model_params["diag"] + ) + + def get_logprob(self): + """ + Returns the logprobability of the lightcurves counts for the + given kernel for the Gaussian Process + """ + cond = self.maingp.condition(self.lc.counts) + return cond.log_probability + + def get_model(self): + """ + Returns the model of the Gaussian Process + """ + return (self.Model_type, self.Model_param) + + def plot_kernel(self): + """ + Plots the kernel of the Gaussian Process + """ + X = self.lc.time + Y = self.kernel(X, np.array([0.0])) + plt.plot(X, Y) + plt.xlabel("distance") + plt.ylabel("Value") + plt.title("Kernel Function") + + def plot_originalgp(self, sample_no=1, seed=0): + """ + Plots samples obtained from the gaussian process for the kernel + + Parameters + ---------- + sample_no: int , default = 1 + Number of GP samples to be taken + + """ + X_test = self.lc.time + _, ax = plt.subplots(1, 1, figsize=(10, 3)) + y_samp = self.maingp.sample(jax.random.PRNGKey(seed), shape=(sample_no,)) + ax.plot(X_test, y_samp[0], "C0", lw=0.5, alpha=0.5, label="samples") + ax.plot(X_test, y_samp[1:].T, "C0", lw=0.5, alpha=0.5) + ax.set_xlabel("time") + ax.set_ylabel("counts") + ax.legend(loc="best") + + def plot_gp(self, sample_no=1, seed=0): + """ + Plots gaussian process, conditioned on the lightcurve + Also, plots the lightcurve along with it + + Parameters + ---------- + sample_no: int , default = 1 + Number of GP samples to be taken + + """ + X_test = self.lc.time + + _, ax = plt.subplots(1, 1, figsize=(10, 3)) + _, cond_gp = self.maingp.condition(self.lc.counts, X_test) + mu = cond_gp.mean + # std = np.sqrt(cond_gp.variance) + + ax.plot(self.lc.time, self.lc.counts, lw=2, color="blue", label="Lightcurve") + ax.plot(X_test, mu, "C1", label="Gaussian Process") + y_samp = cond_gp.sample(jax.random.PRNGKey(seed), shape=(sample_no,)) + ax.plot(X_test, y_samp[0], "C0", lw=0.5, alpha=0.5) + ax.set_xlabel("time") + ax.set_ylabel("counts") + ax.legend(loc="best") + + +def get_prior(kernel_type, mean_type, **kwargs): + """ + A prior generator function based on given values + + Parameters + ---------- + kwargs: + All possible keyword arguments to construct the prior. + + Returns + ------- + The Prior function. + The arguments of the prior function are in the order of + Kernel arguments (RN arguments, QPO arguments), + Mean arguments + Non Windowed arguments + + """ + kwargs["T"] = kwargs["Times"][-1] - kwargs["Times"][0] # Total time + kwargs["f"] = 1 / (kwargs["Times"][1] - kwargs["Times"][0]) # Sampling frequency + kwargs["min"] = jnp.min(kwargs["counts"]) + kwargs["max"] = jnp.max(kwargs["counts"]) + kwargs["span"] = kwargs["max"] - kwargs["min"] + + def RNprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + sig = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig") + return arn, crn, A, t0, sig + + if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return RNprior_model + + def QPOprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") + cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") + freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + sig = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig") + + return arn, crn, aqpo, cqpo, freq, A, t0, sig + + if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return QPOprior_model + + +def get_likelihood(kernel_type, mean_type, **kwargs): + """ + A likelihood generator function based on given values + """ + + @jit + def RNlog_likelihood(arn, crn, A, t0, sig): + rnlikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": 0.0, + "cqpo": 0.0, + "freq": 0.0, + } + + mean_params = { + "A": A, + "t0": t0, + "sig": sig, + } + + kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) + + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return RNlog_likelihood + + @jit + def QPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig): + qpolikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": aqpo, + "cqpo": cqpo, + "freq": freq, + } + + mean_params = { + "A": A, + "t0": t0, + "sig": sig, + } + + kernel = get_kernel(kernel_type="RN", kernel_params=qpolikelihood_params) + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return QPOlog_likelihood + + +class GPResult: + """ + Makes a GP regressor for a given GP class and a prior over it. + Provides the sampled hyperparameters and tabulates their charachtersistics + Using jaxns for nested sampling and evidence analysis + + Parameters + ---------- + GP: class: GP + The initial GP class, on which we will apply our regressor. + + prior_type: string tuple + Has two strings with the first being the name of the kernel type + and the secound being the mean type for the prior + + prior_parameters: dict, default = None + Dictionary containing the parameters for the mean and kernel priors + The keys should be accourding to the selected kernel and mean + prior coressponding to the prior_type + By default, it takes a value None, and the kernel and mean priors are + then bulit using the pre-set parameters. + + Other Parameters + ---------------- + lc: Stingray.Lightcurve object + The lightcurve on which the gaussian process regression, is to be done + + """ + + def __init__(self, GP: GP, prior_type: tuple, prior_parameters=None) -> None: + self.gpclass = GP + self.prior_type = prior_type + self.prior_parameters = prior_parameters + self.lc = GP.lc + + def run_sampling(self): + """ + Runs a sampling process for the hyperparameters for the GP model. + Based on No U turn Sampling from the numpyro module + """ + + dict = {"Times": self.lc.time, "counts": self.lc.counts} + self.prior_model = get_prior(self.prior_type[0], self.prior_type[1], **dict) + self.likelihood = get_likelihood(self.prior_type[0], self.prior_type[1], **dict) + + NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood) + + NSmodel.sanity_check(random.PRNGKey(10), S=100) + + self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) + Termination_reason, State = self.Exact_ns( + random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) + ) + self.Results = self.Exact_ns.to_results(State, Termination_reason) + + def print_summary(self): + """ + Prints a summary table for the model parameters + """ + self.Exact_ns.summary(self.Results) + + def plot_diagnostics(self): + """ + Plots the diagnostic plots for the sampling process + """ + self.Exact_ns.plot_diagnostics(self.Results) + + def corner_plot(self): + """ + Plots the corner plot for the sampled hyperparameters + """ + self.Exact_ns.plot_corner(self.Results) + + def get_parameters(self): + """ + Returns the optimal parameters for the model based on the NUTS sampling + """ + + pass + + def plot_posterior(self, X_test): + """ + Plots posterior gaussian process, conditioned on the lightcurve + Also, plots the lightcurve along with it + + Parameters + ---------- + X_test: jnp.array + Array over which the Gaussian process values are to be obtained + Can be made default with lc.times as default + + """ + + pass From 04b338b2ace4d69a44da9f60f299d1d58d51ad78 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 14 Jun 2023 18:43:55 +0530 Subject: [PATCH 02/50] Added skew means --- stingray/modeling/gpmodeling.py | 212 +++++++++++++++++++++++++++++++- 1 file changed, 209 insertions(+), 3 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 1c2917a28..8f8f0793a 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -78,23 +78,128 @@ def get_mean(mean_type, mean_params): mean = functools.partial(_exponential, mean_params=mean_params) elif mean_type == "constant": mean = functools.partial(_constant, mean_params=mean_params) + elif mean_type == "skew_gaussian": + mean = functools.partial(_skew_gaussian, mean_params=mean_params) + elif mean_type == "skew_exponential": + mean = functools.partial(_skew_exponential, mean_params=mean_params) return mean def _gaussian(t, mean_params): + """A gaussian flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + sig1: + The width parameter for the gaussian. + + Returns + ------- + The y values for the gaussian flare. + """ return mean_params["A"] * jnp.exp( -((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig"] ** 2)) ) def _exponential(t, mean_params): + """An exponential flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + sig1: + The width parameter for the exponential. + + Returns + ------- + The y values for exponential flare. + """ return mean_params["A"] * jnp.exp(-jnp.abs((t - mean_params["t0"])) / mean_params["sig"]) def _constant(t, mean_params): + """A constant mean shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Constant amplitude of the flare. + + Returns + ------- + The constant value. + """ return mean_params["A"] * jnp.ones_like(t) +def _skew_gaussian(t, mean_params): + """A skew gaussian flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + sig1: + The width parameter for the rising edge. + sig2: + The width parameter for the falling edge. + + Returns + ------- + The y values for skew gaussian flare. + """ + return mean_params["A"] * jnp.where( + t > mean_params["t0"], + jnp.exp(-((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig2"] ** 2))), + jnp.exp(-((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig1"] ** 2))), + ) + + +def _skew_exponential(t, mean_params): + """A skew exponential flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + sig1: + The width parameter for the rising edge. + sig2: + The width parameter for the falling edge. + + Returns + ------- + The y values for exponential flare. + """ + return mean_params["A"] * jnp.where( + t > mean_params["t0"], + jnp.exp(-(t - mean_params["t0"]) / mean_params["sig2"]), + jnp.exp((t - mean_params["t0"]) / mean_params["sig1"]), + ) + + class GP: """ Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian @@ -272,6 +377,49 @@ def QPOprior_model(): if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): return QPOprior_model + def skew_RNprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + sig1 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig1") + sig2 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig2") + + return arn, crn, A, t0, sig1, sig2 + + if (kernel_type == "RN") & ((mean_type == "skew_gaussian") | (mean_type == "skew_exponential")): + return skew_RNprior_model + + def skew_QPOprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") + cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") + freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + sig1 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig1") + sig2 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig2") + + return arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2 + + if (kernel_type == "QPO_plus_RN") & ( + (mean_type == "skew_gaussian") | (mean_type == "skew_exponential") + ): + return skew_QPOprior_model + def get_likelihood(kernel_type, mean_type, **kwargs): """ @@ -320,7 +468,7 @@ def QPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig): "sig": sig, } - kernel = get_kernel(kernel_type="RN", kernel_params=qpolikelihood_params) + kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) mean = get_mean(mean_type=mean_type, mean_params=mean_params) gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) @@ -329,6 +477,63 @@ def QPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig): if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): return QPOlog_likelihood + @jit + def skew_RNlog_likelihood(arn, crn, A, t0, sig1, sig2): + rnlikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": 0.0, + "cqpo": 0.0, + "freq": 0.0, + } + + mean_params = { + "A": A, + "t0": t0, + "sig1": sig1, + "sig2": sig2, + } + + kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) + + # This could be causing problems + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return skew_RNlog_likelihood + + @jit + def skewQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2): + qpolikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": aqpo, + "cqpo": cqpo, + "freq": freq, + } + + mean_params = { + "A": A, + "t0": t0, + "sig1": sig1, + "sig2": sig2, + } + + kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) + + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "QPO_plus_RN") & ( + (mean_type == "skew_gaussian") | (mean_type == "skew_exponential") + ): + return skewQPOlog_likelihood + class GPResult: """ @@ -384,6 +589,7 @@ def run_sampling(self): random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) self.Results = self.Exact_ns.to_results(State, Termination_reason) + print("Simulation Complete") def print_summary(self): """ @@ -397,11 +603,11 @@ def plot_diagnostics(self): """ self.Exact_ns.plot_diagnostics(self.Results) - def corner_plot(self): + def plot_cornerplot(self): """ Plots the corner plot for the sampled hyperparameters """ - self.Exact_ns.plot_corner(self.Results) + self.Exact_ns.plot_cornerplot(self.Results) def get_parameters(self): """ From 7a44249c2e06bf32db51a419cac1e65a9fd21569 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Mon, 19 Jun 2023 21:17:34 +0530 Subject: [PATCH 03/50] Added fred model --- stingray/modeling/gpmodeling.py | 133 ++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 8f8f0793a..5b04ff120 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -82,6 +82,8 @@ def get_mean(mean_type, mean_params): mean = functools.partial(_skew_gaussian, mean_params=mean_params) elif mean_type == "skew_exponential": mean = functools.partial(_skew_exponential, mean_params=mean_params) + elif mean_type == "fred": + mean = functools.partial(_fred, mean_params=mean_params) return mean @@ -200,6 +202,39 @@ def _skew_exponential(t, mean_params): ) +def _fred(t, mean_params): + """A fast rise exponential decay (FRED) flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + phi: + Symmetry parameter of the flare. + delta: + Offset parameter of the flare. + + Returns + ------- + The y values for exponential flare. + """ + return ( + mean_params["A"] + * jnp.exp( + -mean_params["phi"] + * ( + (t + mean_params["delta"]) / mean_params["t0"] + + mean_params["t0"] / (t + mean_params["delta"]) + ) + ) + * jnp.exp(2 * mean_params["phi"]) + ) + + class GP: """ Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian @@ -420,6 +455,47 @@ def skew_QPOprior_model(): ): return skew_QPOprior_model + def fred_RNprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + phi = yield Prior(tfpd.Uniform(2 * jnp.exp(-2), 2 * jnp.exp(4)), name="phi") + delta = yield Prior(tfpd.Uniform(0, kwargs["Times"][-1] / 2), name="delta") + + return arn, crn, A, t0, phi, delta + + if (kernel_type == "RN") & (mean_type == "fred"): + return fred_RNprior_model + + def fred_QPOprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") + cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") + freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + phi = yield Prior(tfpd.Uniform(2 * jnp.exp(-2), 2 * jnp.exp(4)), name="phi") + delta = yield Prior(tfpd.Uniform(0, kwargs["Times"][-1] / 2), name="delta") + + return arn, crn, aqpo, cqpo, freq, A, t0, phi, delta + + if (kernel_type == "QPO_plus_RN") & (mean_type == "fred"): + return fred_QPOprior_model + def get_likelihood(kernel_type, mean_type, **kwargs): """ @@ -534,6 +610,63 @@ def skewQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2): ): return skewQPOlog_likelihood + @jit + def fred_RNlog_likelihood(arn, crn, A, t0, phi, delta): + rnlikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": 0.0, + "cqpo": 0.0, + "freq": 0.0, + } + + mean_params = { + "A": A, + "t0": t0, + "phi": phi, + "delta": delta, + } + + kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) + + # This could be causing problems + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + # gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "RN") & (mean_type == "fred"): + return fred_RNlog_likelihood + + @jit + def fredQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, phi, delta): + qpolikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": aqpo, + "cqpo": cqpo, + "freq": freq, + } + + mean_params = { + "A": A, + "t0": t0, + "phi": phi, + "delta": delta, + } + + kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) + + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + # gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "QPO_plus_RN") & (mean_type == "fred"): + return fredQPOlog_likelihood + class GPResult: """ From 3bb4972cd9ab5e67e63417d1ea2da42b0bf19d1d Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 27 Jun 2023 16:05:57 +0530 Subject: [PATCH 04/50] Combined GP, GPR class --- stingray/modeling/gpmodeling.py | 492 ++++++++------------------------ 1 file changed, 120 insertions(+), 372 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 5b04ff120..ab810caf3 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -21,7 +21,7 @@ tfpd = tfp.distributions tfpb = tfp.bijectors -__all__ = ["GP", "GPResult"] +__all__ = ["GP"] def get_kernel(kernel_type, kernel_params): @@ -235,6 +235,99 @@ def _fred(t, mean_params): ) +def get_kernel_params(kernel_type): + if kernel_type == "RN": + return ["arn", "crn"] + elif kernel_type == "QPO_plus_RN": + return ["arn", "crn", "aqpo", "cqpo", "freq"] + + +def get_mean_params(mean_type): + if (mean_type == "gaussian") or (mean_type == "exponential"): + return ["A", "t0", "sig"] + elif mean_type == "constant": + return ["A"] + elif (mean_type == "skew_gaussian") or (mean_type == "skew_exponential"): + return ["A", "t0", "sig1", "sig2"] + elif mean_type == "fred": + return ["A", "t0", "delta", "phi"] + + +def get_gp_params(kernel_type, mean_type): + kernel_params = get_kernel_params(kernel_type) + mean_params = get_mean_params(mean_type) + kernel_params.extend(mean_params) + return kernel_params + + +def get_prior(params_list, prior_dict): + """ + A prior generator function based on given values + + Parameters + ---------- + params_list: + A list in order of the parameters to be used. + + prior_dict: + A dictionary of the priors of parameters to be used. + + Returns + ------- + The Prior function. + The arguments of the prior function are in the order of + Kernel arguments (RN arguments, QPO arguments), + Mean arguments + Non Windowed arguments + + """ + + def prior_model(): + prior_list = [] + for i in params_list: + if isinstance(prior_dict[i], tfpd.Distribution): + parameter = yield Prior(prior_dict[i], name=i) + else: + parameter = yield prior_dict[i] + prior_list.append(parameter) + return tuple(prior_list) + + return prior_model + + +def get_likelihood(params_list, kernel_type, mean_type, **kwargs): + """ + A likelihood generator function based on given values + + Parameters + ---------- + params_list: + A list in order of the parameters to be used. + + prior_dict: + A dictionary of the priors of parameters to be used. + + kernel_type: + The type of kernel to be used in the model. + + mean_type: + The type of mean to be used in the model. + + """ + + @jit + def likelihood_model(*args): + dict = {} + for i, params in enumerate(params_list): + dict[params] = args[i] + kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict) + mean = get_mean(mean_type=mean_type, mean_params=dict) + gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) + return gp.log_probability(kwargs["counts"]) + + return likelihood_model + + class GP: """ Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian @@ -269,14 +362,16 @@ class GP: """ - def __init__(self, Lc: Lightcurve, Model_type: tuple, Model_params: dict = None) -> None: + def __init__(self, Lc: Lightcurve) -> None: self.lc = Lc - self.Model_type = Model_type - self.Model_param = Model_params - self.kernel = get_kernel(self.Model_type[0], self.Model_param) - self.mean = get_mean(self.Model_type[1], self.Model_param) + self.time = Lc.time + self.counts = Lc.counts + + def fit(self, kernel=None, mean=None, **kwargs): + self.kernel = kernel + self.mean = mean self.maingp = GaussianProcess( - self.kernel, Lc.time, mean=self.mean, diag=Model_params["diag"] + self.kernel, self.time, mean_value=self.mean(self.time), diag=kwargs["diag"] ) def get_logprob(self): @@ -287,12 +382,6 @@ def get_logprob(self): cond = self.maingp.condition(self.lc.counts) return cond.log_probability - def get_model(self): - """ - Returns the model of the Gaussian Process - """ - return (self.Model_type, self.Model_param) - def plot_kernel(self): """ Plots the kernel of the Gaussian Process @@ -349,372 +438,31 @@ def plot_gp(self, sample_no=1, seed=0): ax.set_ylabel("counts") ax.legend(loc="best") + def sample(self, prior_model=None, likelihood_model=None, **kwargs): + """ + Makes a Jaxns nested sampler over the Gaussian Process, given the + prior and likelihood model -def get_prior(kernel_type, mean_type, **kwargs): - """ - A prior generator function based on given values - - Parameters - ---------- - kwargs: - All possible keyword arguments to construct the prior. - - Returns - ------- - The Prior function. - The arguments of the prior function are in the order of - Kernel arguments (RN arguments, QPO arguments), - Mean arguments - Non Windowed arguments - - """ - kwargs["T"] = kwargs["Times"][-1] - kwargs["Times"][0] # Total time - kwargs["f"] = 1 / (kwargs["Times"][1] - kwargs["Times"][0]) # Sampling frequency - kwargs["min"] = jnp.min(kwargs["counts"]) - kwargs["max"] = jnp.max(kwargs["counts"]) - kwargs["span"] = kwargs["max"] - kwargs["min"] - - def RNprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - sig = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig") - return arn, crn, A, t0, sig - - if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return RNprior_model - - def QPOprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") - cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") - freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - sig = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig") - - return arn, crn, aqpo, cqpo, freq, A, t0, sig - - if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return QPOprior_model - - def skew_RNprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - sig1 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig1") - sig2 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig2") - - return arn, crn, A, t0, sig1, sig2 - - if (kernel_type == "RN") & ((mean_type == "skew_gaussian") | (mean_type == "skew_exponential")): - return skew_RNprior_model - - def skew_QPOprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") - cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") - freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - sig1 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig1") - sig2 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig2") - - return arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2 - - if (kernel_type == "QPO_plus_RN") & ( - (mean_type == "skew_gaussian") | (mean_type == "skew_exponential") - ): - return skew_QPOprior_model - - def fred_RNprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - phi = yield Prior(tfpd.Uniform(2 * jnp.exp(-2), 2 * jnp.exp(4)), name="phi") - delta = yield Prior(tfpd.Uniform(0, kwargs["Times"][-1] / 2), name="delta") - - return arn, crn, A, t0, phi, delta - - if (kernel_type == "RN") & (mean_type == "fred"): - return fred_RNprior_model - - def fred_QPOprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") - cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") - freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - phi = yield Prior(tfpd.Uniform(2 * jnp.exp(-2), 2 * jnp.exp(4)), name="phi") - delta = yield Prior(tfpd.Uniform(0, kwargs["Times"][-1] / 2), name="delta") - - return arn, crn, aqpo, cqpo, freq, A, t0, phi, delta - - if (kernel_type == "QPO_plus_RN") & (mean_type == "fred"): - return fred_QPOprior_model - - -def get_likelihood(kernel_type, mean_type, **kwargs): - """ - A likelihood generator function based on given values - """ - - @jit - def RNlog_likelihood(arn, crn, A, t0, sig): - rnlikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": 0.0, - "cqpo": 0.0, - "freq": 0.0, - } - - mean_params = { - "A": A, - "t0": t0, - "sig": sig, - } - - kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) - - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return RNlog_likelihood - - @jit - def QPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig): - qpolikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": aqpo, - "cqpo": cqpo, - "freq": freq, - } - - mean_params = { - "A": A, - "t0": t0, - "sig": sig, - } - - kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return QPOlog_likelihood - - @jit - def skew_RNlog_likelihood(arn, crn, A, t0, sig1, sig2): - rnlikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": 0.0, - "cqpo": 0.0, - "freq": 0.0, - } - - mean_params = { - "A": A, - "t0": t0, - "sig1": sig1, - "sig2": sig2, - } - - kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) - - # This could be causing problems - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return skew_RNlog_likelihood - - @jit - def skewQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2): - qpolikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": aqpo, - "cqpo": cqpo, - "freq": freq, - } - - mean_params = { - "A": A, - "t0": t0, - "sig1": sig1, - "sig2": sig2, - } - - kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) - - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "QPO_plus_RN") & ( - (mean_type == "skew_gaussian") | (mean_type == "skew_exponential") - ): - return skewQPOlog_likelihood - - @jit - def fred_RNlog_likelihood(arn, crn, A, t0, phi, delta): - rnlikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": 0.0, - "cqpo": 0.0, - "freq": 0.0, - } - - mean_params = { - "A": A, - "t0": t0, - "phi": phi, - "delta": delta, - } - - kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) - - # This could be causing problems - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - # gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "RN") & (mean_type == "fred"): - return fred_RNlog_likelihood - - @jit - def fredQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, phi, delta): - qpolikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": aqpo, - "cqpo": cqpo, - "freq": freq, - } - - mean_params = { - "A": A, - "t0": t0, - "phi": phi, - "delta": delta, - } - - kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) - - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - # gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "QPO_plus_RN") & (mean_type == "fred"): - return fredQPOlog_likelihood - - -class GPResult: - """ - Makes a GP regressor for a given GP class and a prior over it. - Provides the sampled hyperparameters and tabulates their charachtersistics - Using jaxns for nested sampling and evidence analysis - - Parameters - ---------- - GP: class: GP - The initial GP class, on which we will apply our regressor. - - prior_type: string tuple - Has two strings with the first being the name of the kernel type - and the secound being the mean type for the prior - - prior_parameters: dict, default = None - Dictionary containing the parameters for the mean and kernel priors - The keys should be accourding to the selected kernel and mean - prior coressponding to the prior_type - By default, it takes a value None, and the kernel and mean priors are - then bulit using the pre-set parameters. - - Other Parameters - ---------------- - lc: Stingray.Lightcurve object - The lightcurve on which the gaussian process regression, is to be done + Parameters + ---------- + prior_model: jaxns.prior.PriorModelType object + A prior generator object - """ + likelihood_model: jaxns.types.LikelihoodType object + A likelihood fucntion which takes in the arguments of the prior + model and returns the loglikelihood of the model - def __init__(self, GP: GP, prior_type: tuple, prior_parameters=None) -> None: - self.gpclass = GP - self.prior_type = prior_type - self.prior_parameters = prior_parameters - self.lc = GP.lc + Returns + ---------- + Results: jaxns.results.NestedSamplerResults object + The results of the nested sampling process - def run_sampling(self): """ - Runs a sampling process for the hyperparameters for the GP model. - Based on No U turn Sampling from the numpyro module - """ - - dict = {"Times": self.lc.time, "counts": self.lc.counts} - self.prior_model = get_prior(self.prior_type[0], self.prior_type[1], **dict) - self.likelihood = get_likelihood(self.prior_type[0], self.prior_type[1], **dict) - NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood) + self.prior_model = prior_model + self.likelihood_model = likelihood_model + NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) From 9fe6a4715d825f780f7e19ce5fb3cfcd1b355e87 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 27 Jun 2023 23:01:31 +0530 Subject: [PATCH 05/50] Added kernel, mean tests --- stingray/modeling/gpmodeling.py | 14 +++-- stingray/modeling/tests/test_gpmodeling.py | 60 ++++++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) create mode 100644 stingray/modeling/tests/test_gpmodeling.py diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index ab810caf3..a26b74a09 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -105,9 +105,11 @@ def _gaussian(t, mean_params): ------- The y values for the gaussian flare. """ - return mean_params["A"] * jnp.exp( - -((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig"] ** 2)) - ) + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + + return jnp.sum(A * jnp.exp(-((t - t0) ** 2) / (2 * (sig**2))), axis=0) def _exponential(t, mean_params): @@ -128,7 +130,11 @@ def _exponential(t, mean_params): ------- The y values for exponential flare. """ - return mean_params["A"] * jnp.exp(-jnp.abs((t - mean_params["t0"])) / mean_params["sig"]) + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + + return jnp.sum(A * jnp.exp(-jnp.abs(t - t0) / (2 * (sig**2))), axis=0) def _constant(t, mean_params): diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py new file mode 100644 index 000000000..19e2305dc --- /dev/null +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -0,0 +1,60 @@ +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow_probability.substrates.jax as tfp +import matplotlib.pyplot as plt + +from tinygp import GaussianProcess, kernels +from stingray.modeling.gpmodeling import get_kernel, get_mean, GP + + +class Testget_kernel(object): + def setup_class(self): + self.x = np.linspace(0, 1, 5) + self.kernel_params = {"arn": 1.0, "aqpo": 1.0, "crn": 1.0, "cqpo": 1.0, "freq": 1.0} + + def test_get_kernel_qpo_plus_rn(self): + kernel_qpo_plus_rn = kernels.quasisep.Exp( + scale=1 / 1, sigma=(1) ** 0.5 + ) + kernels.quasisep.Celerite( + a=1, + b=0.0, + c=1, + d=2 * jnp.pi * 1, + ) + kernel_qpo_plus_rn_test = get_kernel("QPO_plus_RN", self.kernel_params) + assert ( + kernel_qpo_plus_rn(self.x, jnp.array([0.0])) + == kernel_qpo_plus_rn_test(self.x, jnp.array([0.0])) + ).all() + + def test_get_kernel_rn(self): + kernel_rn = kernels.quasisep.Exp(scale=1 / 1, sigma=(1) ** 0.5) + kernel_rn_test = get_kernel("RN", self.kernel_params) + assert ( + kernel_rn(self.x, jnp.array([0.0])) == kernel_rn_test(self.x, jnp.array([0.0])) + ).all() + + +class Testget_mean(object): + def setup_class(self): + self.t = np.linspace(0, 5, 10) + self.mean_params_gaussian = { + "A": jnp.array([3.0, 4.0]), + "t0": jnp.array([0.2, 0.7]), + "sig": jnp.array([0.2, 0.1]), + } + + def test_get_mean_gaussian(self): + result_gaussian = 3 * jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.2**2))) + 4 * jnp.exp( + -((self.t - 0.7) ** 2) / (2 * (0.1**2)) + ) + assert (get_mean("gaussian", self.mean_params_gaussian)(self.t) == result_gaussian).all() + + def test_get_mean_exponential(self): + result_exponential = 3 * jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.2**2))) + 4 * jnp.exp( + -jnp.abs(self.t - 0.7) / (2 * (0.1**2)) + ) + assert ( + get_mean("exponential", self.mean_params_gaussian)(self.t) == result_exponential + ).all() From 9b77f2a79c4a843a5066cf88c424fc05e1b2aadc Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 28 Jun 2023 01:49:25 +0530 Subject: [PATCH 06/50] Added Multimean and tests --- stingray/modeling/gpmodeling.py | 51 +++++++++++++------- stingray/modeling/tests/test_gpmodeling.py | 55 ++++++++++++++++++++-- 2 files changed, 85 insertions(+), 21 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index a26b74a09..2c9426060 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -174,10 +174,19 @@ def _skew_gaussian(t, mean_params): ------- The y values for skew gaussian flare. """ - return mean_params["A"] * jnp.where( - t > mean_params["t0"], - jnp.exp(-((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig2"] ** 2))), - jnp.exp(-((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig1"] ** 2))), + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + + return jnp.sum( + A + * jnp.where( + t > t0, + jnp.exp(-((t - t0) ** 2) / (2 * (sig2**2))), + jnp.exp(-((t - t0) ** 2) / (2 * (sig1**2))), + ), + axis=0, ) @@ -201,10 +210,19 @@ def _skew_exponential(t, mean_params): ------- The y values for exponential flare. """ - return mean_params["A"] * jnp.where( - t > mean_params["t0"], - jnp.exp(-(t - mean_params["t0"]) / mean_params["sig2"]), - jnp.exp((t - mean_params["t0"]) / mean_params["sig1"]), + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + + return jnp.sum( + A + * jnp.where( + t > t0, + jnp.exp(-(t - t0) / (2 * (sig2**2))), + jnp.exp((t - t0) / (2 * (sig1**2))), + ), + axis=0, ) @@ -228,16 +246,13 @@ def _fred(t, mean_params): ------- The y values for exponential flare. """ - return ( - mean_params["A"] - * jnp.exp( - -mean_params["phi"] - * ( - (t + mean_params["delta"]) / mean_params["t0"] - + mean_params["t0"] / (t + mean_params["delta"]) - ) - ) - * jnp.exp(2 * mean_params["phi"]) + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + phi = jnp.atleast_1d(mean_params["phi"])[:, jnp.newaxis] + delta = jnp.atleast_1d(mean_params["delta"])[:, jnp.newaxis] + + return jnp.sum( + A * jnp.exp(-phi * ((t + delta) / t0 + t0 / (t + delta))) * jnp.exp(2 * phi), axis=0 ) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 19e2305dc..27ddadcec 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -39,22 +39,71 @@ def test_get_kernel_rn(self): class Testget_mean(object): def setup_class(self): self.t = np.linspace(0, 5, 10) - self.mean_params_gaussian = { + self.mean_params = { "A": jnp.array([3.0, 4.0]), "t0": jnp.array([0.2, 0.7]), "sig": jnp.array([0.2, 0.1]), } + self.skew_mean_params = { + "A": jnp.array([3.0, 4.0]), + "t0": jnp.array([0.2, 0.7]), + "sig1": jnp.array([0.2, 0.1]), + "sig2": jnp.array([0.3, 0.4]), + } + self.fred_mean_params = { + "A": jnp.array([3.0, 4.0]), + "t0": jnp.array([0.2, 0.7]), + "phi": jnp.array([4.0, 5.0]), + "delta": jnp.array([0.3, 0.4]), + } def test_get_mean_gaussian(self): result_gaussian = 3 * jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.2**2))) + 4 * jnp.exp( -((self.t - 0.7) ** 2) / (2 * (0.1**2)) ) - assert (get_mean("gaussian", self.mean_params_gaussian)(self.t) == result_gaussian).all() + assert (get_mean("gaussian", self.mean_params)(self.t) == result_gaussian).all() def test_get_mean_exponential(self): result_exponential = 3 * jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.2**2))) + 4 * jnp.exp( -jnp.abs(self.t - 0.7) / (2 * (0.1**2)) ) + assert (get_mean("exponential", self.mean_params)(self.t) == result_exponential).all() + + def test_get_mean_constant(self): + result_constant = 3 * jnp.ones_like(self.t) + const_param_dict = {"A": jnp.array([3.0])} + assert (get_mean("constant", const_param_dict)(self.t) == result_constant).all() + + def test_get_mean_skew_gaussian(self): + result_skew_gaussian = 3.0 * jnp.where( + self.t > 0.2, + jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.3**2))), + jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.2**2))), + ) + 4.0 * jnp.where( + self.t > 0.7, + jnp.exp(-((self.t - 0.7) ** 2) / (2 * (0.4**2))), + jnp.exp(-((self.t - 0.7) ** 2) / (2 * (0.1**2))), + ) assert ( - get_mean("exponential", self.mean_params_gaussian)(self.t) == result_exponential + get_mean("skew_gaussian", self.skew_mean_params)(self.t) == result_skew_gaussian ).all() + + def test_get_mean_skew_exponential(self): + result_skew_exponential = 3.0 * jnp.where( + self.t > 0.2, + jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.3**2))), + jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.2**2))), + ) + 4.0 * jnp.where( + self.t > 0.7, + jnp.exp(-jnp.abs(self.t - 0.7) / (2 * (0.4**2))), + jnp.exp(-jnp.abs(self.t - 0.7) / (2 * (0.1**2))), + ) + assert ( + get_mean("skew_exponential", self.skew_mean_params)(self.t) == result_skew_exponential + ).all() + + def test_get_mean_fred(self): + result_fred = 3.0 * jnp.exp(-4.0 * ((self.t + 0.3) / 0.2 + 0.2 / (self.t + 0.3))) * jnp.exp( + 2 * 4.0 + ) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0) + assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all() From 2c296b1ae31301b9f976b7ac50a90bae5886f410 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Thu, 29 Jun 2023 14:55:38 +0530 Subject: [PATCH 07/50] Testing for get_pior_params --- stingray/modeling/tests/test_gpmodeling.py | 84 +++++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 27ddadcec..abc9632cf 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt from tinygp import GaussianProcess, kernels -from stingray.modeling.gpmodeling import get_kernel, get_mean, GP +from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params class Testget_kernel(object): @@ -107,3 +107,85 @@ def test_get_mean_fred(self): 2 * 4.0 ) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0) assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all() + + class Testget_gp_params(object): + def setup_class(self): + pass + + def test_get_gp_params_rn(self): + assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] + assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] + assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] + assert get_gp_params("RN", "skew_exponential") == [ + "arn", + "crn", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] + assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] + + def test_get_gp_params_qpo_plus_rn(self): + assert get_gp_params("QPO_plus_RN", "gaussian") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig", + ] + assert get_gp_params("QPO_plus_RN", "constant") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + ] + assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("QPO_plus_RN", "exponential") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig", + ] + assert get_gp_params("QPO_plus_RN", "fred") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "delta", + "phi", + ] From 014826b8a50bc79cfa7ae2099051047300795b8b Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 11 Jul 2023 18:23:41 +0530 Subject: [PATCH 08/50] Changed the GP class --- stingray/modeling/gpmodeling.py | 216 +++++++++++++++++--------------- 1 file changed, 112 insertions(+), 104 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 2c9426060..b82803f44 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -15,13 +15,14 @@ # from jaxns import analytic_log_evidence from jaxns import Prior, Model +from jaxns.utils import resample jax.config.update("jax_enable_x64", True) tfpd = tfp.distributions tfpb = tfp.bijectors -__all__ = ["GP"] +__all__ = ["GPResult"] def get_kernel(kernel_type, kernel_params): @@ -349,37 +350,23 @@ def likelihood_model(*args): return likelihood_model -class GP: +class GPResult: """ - Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian - Process on the lightcurve data, for the given kernel. + Makes a GPResult object which takes in a Stingray.Lightcurve and samples parameters of a model + (Gaussian Process) based on the given prior and log_likelihood function. Parameters ---------- lc: Stingray.Lightcurve object - The lightcurve on which the gaussian process, is to be fitted - - Model_type: string tuple - Has two strings with the first being the name of the kernel type - and the secound being the mean type - - Model_parameter: dict, default = None - Dictionary conatining the parameters for the mean and kernel - The keys should be accourding to the selected kernel and mean - coressponding to the Model_type - By default, it takes a value None, and the kernel and mean are - then bulit using the pre-set parameters. + The lightcurve on which the bayesian inference is to be done Other Parameters ---------------- - kernel: class: `TinyGp.kernel` object - The tinygp kernel for the GP + time : class: np.array + The array containing the times of the lightcurve - mean: class: `TinyGp.mean` object - The tinygp mean for the GP - - maingp: class: `TinyGp.GaussianProcess` object - The tinygp gaussian process made on the lightcurve + counts : class: np.array + The array containing the photon counts of the lightcurve """ @@ -387,77 +374,7 @@ def __init__(self, Lc: Lightcurve) -> None: self.lc = Lc self.time = Lc.time self.counts = Lc.counts - - def fit(self, kernel=None, mean=None, **kwargs): - self.kernel = kernel - self.mean = mean - self.maingp = GaussianProcess( - self.kernel, self.time, mean_value=self.mean(self.time), diag=kwargs["diag"] - ) - - def get_logprob(self): - """ - Returns the logprobability of the lightcurves counts for the - given kernel for the Gaussian Process - """ - cond = self.maingp.condition(self.lc.counts) - return cond.log_probability - - def plot_kernel(self): - """ - Plots the kernel of the Gaussian Process - """ - X = self.lc.time - Y = self.kernel(X, np.array([0.0])) - plt.plot(X, Y) - plt.xlabel("distance") - plt.ylabel("Value") - plt.title("Kernel Function") - - def plot_originalgp(self, sample_no=1, seed=0): - """ - Plots samples obtained from the gaussian process for the kernel - - Parameters - ---------- - sample_no: int , default = 1 - Number of GP samples to be taken - - """ - X_test = self.lc.time - _, ax = plt.subplots(1, 1, figsize=(10, 3)) - y_samp = self.maingp.sample(jax.random.PRNGKey(seed), shape=(sample_no,)) - ax.plot(X_test, y_samp[0], "C0", lw=0.5, alpha=0.5, label="samples") - ax.plot(X_test, y_samp[1:].T, "C0", lw=0.5, alpha=0.5) - ax.set_xlabel("time") - ax.set_ylabel("counts") - ax.legend(loc="best") - - def plot_gp(self, sample_no=1, seed=0): - """ - Plots gaussian process, conditioned on the lightcurve - Also, plots the lightcurve along with it - - Parameters - ---------- - sample_no: int , default = 1 - Number of GP samples to be taken - - """ - X_test = self.lc.time - - _, ax = plt.subplots(1, 1, figsize=(10, 3)) - _, cond_gp = self.maingp.condition(self.lc.counts, X_test) - mu = cond_gp.mean - # std = np.sqrt(cond_gp.variance) - - ax.plot(self.lc.time, self.lc.counts, lw=2, color="blue", label="Lightcurve") - ax.plot(X_test, mu, "C1", label="Gaussian Process") - y_samp = cond_gp.sample(jax.random.PRNGKey(seed), shape=(sample_no,)) - ax.plot(X_test, y_samp[0], "C0", lw=0.5, alpha=0.5) - ax.set_xlabel("time") - ax.set_ylabel("counts") - ax.legend(loc="best") + self.Result = None def sample(self, prior_model=None, likelihood_model=None, **kwargs): """ @@ -493,6 +410,12 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): self.Results = self.Exact_ns.to_results(State, Termination_reason) print("Simulation Complete") + def get_evidence(self): + """ + Returns the log evidence of the model + """ + return self.Results.log_Z_mean + def print_summary(self): """ Prints a summary table for the model parameters @@ -511,24 +434,109 @@ def plot_cornerplot(self): """ self.Exact_ns.plot_cornerplot(self.Results) - def get_parameters(self): + def get_parameters_names(self): + """ + Returns the names of the parameters + """ + return sorted(self.Results.samples.keys()) + + def get_max_posterior_parameters(self): """ Returns the optimal parameters for the model based on the NUTS sampling """ + max_post_idx = jnp.argmax(self.Results.log_posterior_density) + map_points = jax.tree_map(lambda x: x[max_post_idx], self.Results.samples) - pass + return map_points - def plot_posterior(self, X_test): + def get_max_likelihood_parameters(self): + """ + Retruns the maximum likelihood parameters """ - Plots posterior gaussian process, conditioned on the lightcurve - Also, plots the lightcurve along with it + max_like_idx = jnp.argmax(self.Results.log_L_samples) + max_like_points = jax.tree_map(lambda x: x[max_like_idx], self.Results.samples) - Parameters - ---------- - X_test: jnp.array - Array over which the Gaussian process values are to be obtained - Can be made default with lc.times as default + return max_like_points + + def posterior_plot(self, name: str, n=0): + """ + Plots the posterior histogram for the given parameter + """ + nsamples = self.Results.total_num_samples + samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] + plt.hist( + samples, bins="auto", density=True, alpha=1.0, label=name, fc="None", edgecolor="black" + ) + mean1 = jnp.mean(self.Results.samples[name]) + std1 = jnp.std(self.Results.samples[name]) + plt.axvline(mean1, color="red", linestyle="dashed", label="mean") + plt.axvline(mean1 + std1, color="green", linestyle="dotted") + plt.axvline(mean1 - std1, linestyle="dotted", color="green") + plt.legend() + plt.plot() + + pass + + def weighted_posterior_plot(self, name: str, n=0, rkey=random.PRNGKey(1234)): + """ + Returns the weighted posterior histogram for the given parameter + """ + nsamples = self.Results.total_num_samples + log_p = self.Results.log_dp_mean + samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] + + weights = jnp.where(jnp.isfinite(samples), jnp.exp(log_p), 0.0) + log_weights = jnp.where(jnp.isfinite(samples), log_p, -jnp.inf) + samples_resampled = resample( + rkey, samples, log_weights, S=max(10, int(self.Results.ESS)), replace=True + ) + nbins = max(10, int(jnp.sqrt(self.Results.ESS)) + 1) + binsx = jnp.linspace(*jnp.percentile(samples_resampled, jnp.asarray([0, 100])), 2 * nbins) + + plt.hist( + np.asarray(samples_resampled), + bins=binsx, + density=True, + alpha=1.0, + label=name, + fc="None", + edgecolor="black", + ) + sample_mean = jnp.average(samples, weights=weights) + sample_std = jnp.sqrt(jnp.average((samples - sample_mean) ** 2, weights=weights)) + plt.axvline(sample_mean, color="red", linestyle="dashed", label="mean") + plt.axvline(sample_mean + sample_std, color="green", linestyle="dotted") + plt.axvline(sample_mean - sample_std, linestyle="dotted", color="green") + plt.legend() + plt.plot() + + def corner_plot(self, param1: str, param2: str, n1=0, n2=0, rkey=random.PRNGKey(1234)): """ + Plots the corner plot for the given parameters + """ + nsamples = self.Results.total_num_samples + log_p = self.Results.log_dp_mean + samples1 = self.Results.samples[param1].reshape((nsamples, -1))[:, n1] + samples2 = self.Results.samples[param2].reshape((nsamples, -1))[:, n2] + + log_weights = jnp.where(jnp.isfinite(samples2), log_p, -jnp.inf) + nbins = max(10, int(jnp.sqrt(self.Results.ESS)) + 1) + + samples_resampled = resample( + rkey, + jnp.stack([samples1, samples2], axis=-1), + log_weights, + S=max(10, int(self.Results.ESS)), + replace=True, + ) + plt.hist2d( + samples_resampled[:, 1], + samples_resampled[:, 0], + bins=(nbins, nbins), + density=True, + cmap="GnBu", + ) + plt.plot() pass From bb33727e4f10a41e289b1cf469a04f27786ff4fb Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 18 Jul 2023 19:46:07 +0530 Subject: [PATCH 09/50] Improved library imports --- stingray/modeling/gpmodeling.py | 53 +++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index b82803f44..f9178ce37 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -1,26 +1,41 @@ import numpy as np import matplotlib.pyplot as plt -import jax -import jax.numpy as jnp import functools -import tensorflow_probability.substrates.jax as tfp +from stingray import Lightcurve + +try: + import jax +except ImportError: + raise ImportError("Jax not installed") +import jax.numpy as jnp from jax import jit, random -from tinygp import GaussianProcess, kernels -from stingray import Lightcurve +jax.config.update("jax_enable_x64", True) -from jaxns import ExactNestedSampler -from jaxns import TerminationCondition +try: + from tinygp import GaussianProcess, kernels -# from jaxns import analytic_log_evidence -from jaxns import Prior, Model -from jaxns.utils import resample + can_make_gp = True +except ImportError: + can_make_gp = False -jax.config.update("jax_enable_x64", True) +try: + from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model + from jaxns.utils import resample + + can_sample = True +except ImportError: + can_sample = False +try: + import tensorflow_probability.substrates.jax as tfp + + tfpd = tfp.distributions + tfpb = tfp.bijectors + tfp_available = True +except ImportError: + tfp_available = False -tfpd = tfp.distributions -tfpb = tfp.bijectors __all__ = ["GPResult"] @@ -41,6 +56,9 @@ def get_kernel(kernel_type, kernel_params): Should contain the parameters for the selected kernel """ + if not can_make_gp: + raise ImportError("Tinygp is required to make kernels") + if kernel_type == "QPO_plus_RN": kernel = kernels.quasisep.Exp( scale=1 / kernel_params["crn"], sigma=(kernel_params["arn"]) ** 0.5 @@ -303,6 +321,11 @@ def get_prior(params_list, prior_dict): Non Windowed arguments """ + if not can_sample: + raise ImportError("Jaxns not installed. Cannot make jaxns specific prior.") + + if not tfp_available: + raise ImportError("Tensorflow probability required to make priors.") def prior_model(): prior_list = [] @@ -336,6 +359,8 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): The type of mean to be used in the model. """ + if not can_make_gp: + raise ImportError("Tinygp is required to make the GP model.") @jit def likelihood_model(*args): @@ -396,6 +421,8 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): The results of the nested sampling process """ + if not can_sample: + raise ImportError("Jaxns not installed! Can't sample!") self.prior_model = prior_model self.likelihood_model = likelihood_model From 596c63b8fc20a58265bb05ecff05d897b3571ef7 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 18 Jul 2023 21:59:42 +0530 Subject: [PATCH 10/50] Primary tests for GPresult class --- stingray/modeling/tests/test_gpmodeling.py | 77 +++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index abc9632cf..f9949c3a6 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -1,11 +1,22 @@ import jax import jax.numpy as jnp +from jax import random + +jax.config.update("jax_enable_x64", True) + import numpy as np -import tensorflow_probability.substrates.jax as tfp import matplotlib.pyplot as plt from tinygp import GaussianProcess, kernels from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params +from stingray.modeling.gpmodeling import get_prior, get_likelihood, GPResult +from stingray import Lightcurve + +import tensorflow_probability.substrates.jax as tfp + +tfpd = tfp.distributions + +from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model class Testget_kernel(object): @@ -189,3 +200,67 @@ def test_get_gp_params_qpo_plus_rn(self): "delta", "phi", ] + + +class TestGPResult(object): + def setup_class(self): + self.Times = np.linspace(0, 1, 64) + kernel_params = { + "arn": jnp.exp(1.5), + "crn": jnp.exp(1.0), + } + mean_params = {"A": jnp.array([3.0]), "t0": jnp.array([0.2]), "sig": jnp.array([0.2])} + kernel = get_kernel("RN", kernel_params) + mean = get_mean("gaussian", mean_params) + + gp = GaussianProcess(kernel=kernel, X=self.Times, mean_value=mean(self.Times)) + self.counts = gp.sample(key=jax.random.PRNGKey(6)) + + lc = Lightcurve(time=self.Times, counts=self.counts, dt=self.Times[1] - self.Times[0]) + + self.params_list = get_gp_params(kernel_type="RN", mean_type="gaussian") + + T = self.Times[-1] - self.Times[0] + f = 1 / (self.Times[1] - self.Times[0]) + span = jnp.max(self.counts) - jnp.min(self.counts) + + # The prior dictionary, with suitable tfpd prior distributions + prior_dict = { + "A": tfpd.Uniform(low=0.1 * span, high=2 * span), + "t0": tfpd.Uniform(low=self.Times[0] - 0.1 * T, high=self.Times[-1] + 0.1 * T), + "sig": tfpd.Uniform(low=0.5 * 1 / f, high=2 * T), + "arn": tfpd.Uniform(low=0.1 * span, high=2 * span), + "crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)), + } + + prior_model = get_prior(self.params_list, prior_dict) + likelihood_model = get_likelihood( + self.params_list, + kernel_type="RN", + mean_type="gaussian", + Times=self.Times, + counts=self.counts, + ) + + NSmodel = Model(prior_model=prior_model, log_likelihood=likelihood_model) + NSmodel.sanity_check(random.PRNGKey(10), S=100) + + Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) + Termination_reason, State = Exact_ns( + random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) + ) + self.Results = Exact_ns.to_results(State, Termination_reason) + + self.gpresult = GPResult(lc) + self.gpresult.sample(prior_model=prior_model, likelihood_model=likelihood_model) + + def test_sample(self): + for key in self.params_list: + assert (self.Results.samples[key]).all() == (self.gpresult.Results.samples[key]).all() + + def test_get_evidence(self): + assert self.Results.log_Z_mean == self.gpresult.Results.log_Z_mean + + def plot_diagnostics(self): + self.gpresult.plot_diagnostics() + assert plt.fignum_exists(1) From 01b52a3f2317e48b1d31dfdb8d36ebc5ce2709e8 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 19 Jul 2023 19:20:25 +0530 Subject: [PATCH 11/50] Added docstrings --- stingray/modeling/gpmodeling.py | 112 ++++++++++++++++++--- stingray/modeling/tests/test_gpmodeling.py | 11 +- 2 files changed, 106 insertions(+), 17 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index f9178ce37..38fadbdef 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -43,7 +43,7 @@ def get_kernel(kernel_type, kernel_params): """ Function for producing the kernel for the Gaussian Process. - Returns the selected Tinygp kernel + Returns the selected Tinygp kernel for the given parameters. Parameters ---------- @@ -275,14 +275,38 @@ def _fred(t, mean_params): ) -def get_kernel_params(kernel_type): +def _get_kernel_params(kernel_type): + """ + Generates a list of the parameters for the kernel for the GP model based on the kernel type. + + Parameters + ---------- + kernel_type: string + The type of kernel to be used for the Gaussian Process model + + Returns + ------- + A list of the parameters for the kernel for the GP model + """ if kernel_type == "RN": return ["arn", "crn"] elif kernel_type == "QPO_plus_RN": return ["arn", "crn", "aqpo", "cqpo", "freq"] -def get_mean_params(mean_type): +def _get_mean_params(mean_type): + """ + Generates a list of the parameters for the mean for the GP model based on the mean type. + + Parameters + ---------- + mean_type: string + The type of mean to be used for the Gaussian Process model + + Returns + ------- + A list of the parameters for the mean for the GP model + """ if (mean_type == "gaussian") or (mean_type == "exponential"): return ["A", "t0", "sig"] elif mean_type == "constant": @@ -294,15 +318,39 @@ def get_mean_params(mean_type): def get_gp_params(kernel_type, mean_type): - kernel_params = get_kernel_params(kernel_type) - mean_params = get_mean_params(mean_type) + """ + Generates a list of the parameters for the GP model based on the kernel and mean type. + To be used to set the order of the parameters for `get_prior` and `get_likelihood` functions. + + Parameters + ---------- + kernel_type: string + The type of kernel to be used for the Gaussian Process model + + mean_type: string + The type of mean to be used for the Gaussian Process model + + Returns + ------- + A list of the parameters for the GP model + + Examples + -------- + >>> get_gp_params("QPO_plus_RN", "gaussian") + ['arn', 'crn', 'aqpo', 'cqpo', 'freq', 'A', 't0', 'sig'] + """ + kernel_params = _get_kernel_params(kernel_type) + mean_params = _get_mean_params(mean_type) kernel_params.extend(mean_params) return kernel_params def get_prior(params_list, prior_dict): """ - A prior generator function based on given values + A prior generator function based on given values. + Makes a jaxns specific prior function based on the given prior dictionary. + Jaxns requires the parameters of the prior function and log_likelihood function to + be in the same order. This order is made according to the params_list. Parameters ---------- @@ -311,14 +359,35 @@ def get_prior(params_list, prior_dict): prior_dict: A dictionary of the priors of parameters to be used. + These parameters should be from tensorflow_probability distributions / Priors from jaxns + or special priors from jaxns. + **Note**: If jaxns priors are used, then the name given to them should be the same as + the corresponding name in the params_list. Returns ------- - The Prior function. + The Prior generator function. The arguments of the prior function are in the order of - Kernel arguments (RN arguments, QPO arguments), - Mean arguments - Non Windowed arguments + Kernel arguments (RN arguments, QPO arguments), + Mean arguments + Miscellaneous arguments + + Examples + -------- + A prior function for a Red Noise kernel and a Gaussian mean function + Obain the parameters list + >>> params_list = get_gp_params("RN", "gaussian") + + Make a prior dictionary using tensorflow_probability distributions + >>> prior_dict = { + ... "A": tfpd.Uniform(low = 1e-1, high = 2e+2), + ... "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), + ... "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), + ... "arn": tfpd.Uniform(low = 0.1 , high = 2 ), + ... "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), + ... } + + >>> prior_model = get_prior(params_list, prior_dict) """ if not can_sample: @@ -342,7 +411,11 @@ def prior_model(): def get_likelihood(params_list, kernel_type, mean_type, **kwargs): """ - A likelihood generator function based on given values + A log likelihood generator function based on given values. + Makes a jaxns specific log likelihood function which takes in the + parameters in the order of the parameters list, and calculates the + log likelihood of the data given the parameters, and the model + (kernel, mean) of the GP model. Parameters ---------- @@ -358,6 +431,19 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): mean_type: The type of mean to be used in the model. + **kwargs: + The keyword arguments to be used in the log likelihood function. + **Note**: The keyword arguments Times and counts are necessary for + calculating the log likelihood. + Times: np.array or jnp.array + The time array of the lightcurve + counts: np.array or jnp.array + The photon counts array of the lightcurve + + Returns + ------- + The jaxns specific log likelihood function. + """ if not can_make_gp: raise ImportError("Tinygp is required to make the GP model.") @@ -502,8 +588,6 @@ def posterior_plot(self, name: str, n=0): plt.legend() plt.plot() - pass - def weighted_posterior_plot(self, name: str, n=0, rkey=random.PRNGKey(1234)): """ Returns the weighted posterior histogram for the given parameter @@ -565,5 +649,3 @@ def corner_plot(self, param1: str, param2: str, n1=0, n2=0, rkey=random.PRNGKey( cmap="GnBu", ) plt.plot() - - pass diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index f9949c3a6..01ee0a398 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -259,8 +259,15 @@ def test_sample(self): assert (self.Results.samples[key]).all() == (self.gpresult.Results.samples[key]).all() def test_get_evidence(self): - assert self.Results.log_Z_mean == self.gpresult.Results.log_Z_mean + assert self.Results.log_Z_mean == self.gpresult.get_evidence() - def plot_diagnostics(self): + def test_plot_diagnostics(self): self.gpresult.plot_diagnostics() assert plt.fignum_exists(1) + + def test_plot_cornerplot(self): + self.gpresult.plot_cornerplot() + assert plt.fignum_exists(1) + + def test_get_parameters_names(self): + assert sorted(self.params_list) == self.gpresult.get_parameters_names() From 6069b4c762e0004c21abe4e57558953a756d006d Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Thu, 20 Jul 2023 02:50:37 +0530 Subject: [PATCH 12/50] Plot function changed, tests added --- stingray/modeling/gpmodeling.py | 159 ++++++++++++++++++++- stingray/modeling/tests/test_gpmodeling.py | 93 +++++++++++- 2 files changed, 242 insertions(+), 10 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 38fadbdef..d2e7f2917 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -571,9 +571,37 @@ def get_max_likelihood_parameters(self): return max_like_points - def posterior_plot(self, name: str, n=0): + def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): """ Plots the posterior histogram for the given parameter + + Parameters + ---------- + name : str + Name of the parameter. + Should be from the names of the parameter list used or from the names of parameters + used in the prior_function + + n : int, default 0 + The index of the parameter to be plotted. + For multivariate parameters, the index of the specific parameter to be plotted. + + axis : list, tuple, string, default ``None`` + Parameter to set axis properties of ``matplotlib`` figure. For example + it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other + acceptable argument for ``matplotlib.pyplot.axis()`` method. + + save : bool, optionalm, default ``False`` + If ``True``, save the figure with specified filename. + + filename : str + File name and path of the image to save. Depends on the boolean ``save``. + + Returns + ------- + plt : ``matplotlib.pyplot`` object + Reference to plot, call ``show()`` to display it + """ nsamples = self.Results.total_num_samples samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] @@ -585,12 +613,56 @@ def posterior_plot(self, name: str, n=0): plt.axvline(mean1, color="red", linestyle="dashed", label="mean") plt.axvline(mean1 + std1, color="green", linestyle="dotted") plt.axvline(mean1 - std1, linestyle="dotted", color="green") + plt.title("Posterior Histogram of " + str(name)) + plt.xlabel(name) + plt.ylabel("Probability Density") plt.legend() - plt.plot() - def weighted_posterior_plot(self, name: str, n=0, rkey=random.PRNGKey(1234)): + if axis is not None: + plt.axis(axis) + + if save: + if filename is None: + plt.savefig(str(name) + "_Posterior_plot.png") + else: + plt.savefig(filename) + return plt + + def weighted_posterior_plot( + self, name: str, n=0, rkey=random.PRNGKey(1234), axis=None, save=False, filename=None + ): """ Returns the weighted posterior histogram for the given parameter + + Parameters + ---------- + name : str + Name of the parameter. + Should be from the names of the parameter list used or from the names of parameters + used in the prior_function + + n : int, default 0 + The index of the parameter to be plotted. + For multivariate parameters, the index of the specific parameter to be plotted. + + key: jax.random.PRNGKey, default ``random.PRNGKey(1234)`` + Random key for the weighted sampling + + axis : list, tuple, string, default ``None`` + Parameter to set axis properties of ``matplotlib`` figure. For example + it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other + acceptable argument for ``matplotlib.pyplot.axis()`` method. + + save : bool, optionalm, default ``False`` + If ``True``, save the figure with specified filename. + + filename : str + File name and path of the image to save. Depends on the boolean ``save``. + + Returns + ------- + plt : ``matplotlib.pyplot`` object + Reference to plot, call ``show()`` to display it """ nsamples = self.Results.total_num_samples log_p = self.Results.log_dp_mean @@ -619,12 +691,72 @@ def weighted_posterior_plot(self, name: str, n=0, rkey=random.PRNGKey(1234)): plt.axvline(sample_mean, color="red", linestyle="dashed", label="mean") plt.axvline(sample_mean + sample_std, color="green", linestyle="dotted") plt.axvline(sample_mean - sample_std, linestyle="dotted", color="green") + plt.title("Weighted Posterior Histogram of " + str(name)) + plt.xlabel(name) + plt.ylabel("Probability Density") plt.legend() - plt.plot() + if axis is not None: + plt.axis(axis) - def corner_plot(self, param1: str, param2: str, n1=0, n2=0, rkey=random.PRNGKey(1234)): + if save: + if filename is None: + plt.savefig(str(name) + "_Weighted_Posterior_plot.png") + else: + plt.savefig(filename) + return plt + + def corner_plot( + self, + param1: str, + param2: str, + n1=0, + n2=0, + rkey=random.PRNGKey(1234), + axis=None, + save=False, + filename=None, + ): """ - Plots the corner plot for the given parameters + Plots the corner plot between two given parameters + + Parameters + ---------- + param1 : str + Name of the first parameter. + Should be from the names of the parameter list used or from the names of parameters + used in the prior_function + + param2 : str + Name of the second parameter. + Should be from the names of the parameter list used or from the names of parameters + used in the prior_function + + n1 : int, default 0 + The index of the first parameter to be plotted. + For multivariate parameters, the index of the specific parameter to be plotted. + + n2 : int, default 0 + The index of the second parameter to be plotted. + For multivariate parameters, the index of the specific parameter to be plotted. + + key: jax.random.PRNGKey, default ``random.PRNGKey(1234)`` + Random key for the shuffling the weights + + axis : list, tuple, string, default ``None`` + Parameter to set axis properties of ``matplotlib`` figure. For example + it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other + acceptable argument for ``matplotlib.pyplot.axis()`` method. + + save : bool, optionalm, default ``False`` + If ``True``, save the figure with specified filename. + + filename : str + File name and path of the image to save. Depends on the boolean ``save``. + + Returns + ------- + plt : ``matplotlib.pyplot`` object + Reference to plot, call ``show()`` to display it """ nsamples = self.Results.total_num_samples log_p = self.Results.log_dp_mean @@ -648,4 +780,17 @@ def corner_plot(self, param1: str, param2: str, n1=0, n2=0, rkey=random.PRNGKey( density=True, cmap="GnBu", ) - plt.plot() + plt.title("Corner Plot of " + str(param1) + " and " + str(param2)) + plt.xlabel(param2) + plt.ylabel(param1) + plt.colorbar() + if axis is not None: + plt.axis(axis) + + if save: + if filename is None: + plt.savefig(str(param1) + "_" + str(param2) + "_Corner_plot.png") + else: + plt.savefig(filename) + + return plt diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 01ee0a398..992e34045 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -1,12 +1,13 @@ +import os +import numpy as np +import matplotlib.pyplot as plt + import jax import jax.numpy as jnp from jax import random jax.config.update("jax_enable_x64", True) -import numpy as np -import matplotlib.pyplot as plt - from tinygp import GaussianProcess, kernels from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params from stingray.modeling.gpmodeling import get_prior, get_likelihood, GPResult @@ -19,6 +20,12 @@ from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model +def clear_all_figs(): + fign = plt.get_fignums() + for fig in fign: + plt.close(fig) + + class Testget_kernel(object): def setup_class(self): self.x = np.linspace(0, 1, 5) @@ -271,3 +278,83 @@ def test_plot_cornerplot(self): def test_get_parameters_names(self): assert sorted(self.params_list) == self.gpresult.get_parameters_names() + + def test_print_summary(self): + self.gpresult.print_summary() + assert True + + def test_max_posterior_parameters(self): + for key in self.params_list: + assert key in self.gpresult.get_max_posterior_parameters() + + def test_max_likelihood_parameters(self): + for key in self.params_list: + assert key in self.gpresult.get_max_likelihood_parameters() + + def test_posterior_plot(self): + self.gpresult.posterior_plot("A") + assert plt.fignum_exists(1) + + def test_posterior_plot_labels_and_fname_default(self): + clear_all_figs() + outfname = "A_Posterior_plot.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.posterior_plot("A", save=True) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_posterior_plot_labels_and_fname(self): + clear_all_figs() + outfname = "blabla.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.posterior_plot("A", axis=[0, 14, 0, 0.5], save=True, filename=outfname) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_weighted_posterior_plot(self): + self.gpresult.weighted_posterior_plot("A") + assert plt.fignum_exists(1) + + def test_weighted_posterior_plot_labels_and_fname_default(self): + clear_all_figs() + outfname = "A_Weighted_Posterior_plot.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.weighted_posterior_plot("A", save=True) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_weighted_posterior_plot_labels_and_fname(self): + clear_all_figs() + outfname = "blabla.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.weighted_posterior_plot( + "A", axis=[0, 14, 0, 0.5], save=True, filename=outfname + ) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_corner_plot(self): + self.gpresult.corner_plot("A", "t0") + assert plt.fignum_exists(1) + + def test_corner_plot_labels_and_fname_default(self): + clear_all_figs() + outfname = "A_t0_Corner_plot.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.corner_plot("A", "t0", save=True) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_corner_plot_labels_and_fname(self): + clear_all_figs() + outfname = "blabla.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.corner_plot("A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) + assert os.path.exists(outfname) + os.unlink(outfname) From 2737824d63bcc6b6b44a014b1d782003ae74e5cf Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 26 Jul 2023 20:40:31 +0530 Subject: [PATCH 13/50] Added testing skips --- setup.cfg | 12 +- stingray/modeling/gpmodeling.py | 17 +- stingray/modeling/tests/test_gpmodeling.py | 205 ++++++++++++--------- 3 files changed, 134 insertions(+), 100 deletions(-) diff --git a/setup.cfg b/setup.cfg index e5ac3a72f..9ccd87d24 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,12 +22,6 @@ install_requires = scipy>=1.1.0 ; Matplotlib 3.4.0 is incompatible with Astropy matplotlib>=3.0,!=3.4.0 - jax - tinygp - jaxns - etils - tensorflow_probability - typing_extensions [options.entry_points] console_scripts = @@ -51,6 +45,12 @@ all = xarray pandas ultranest + jax + tinygp + jaxns + etils + tensorflow_probability + typing_extensions test = pytest pytest-astropy diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index d2e7f2917..6841cc57b 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import matplotlib.pyplot as plt import functools @@ -6,12 +7,15 @@ try: import jax except ImportError: - raise ImportError("Jax not installed") + pytest.skip(allow_module_level=True) -import jax.numpy as jnp -from jax import jit, random +try: + import jax.numpy as jnp + from jax import jit, random -jax.config.update("jax_enable_x64", True) + jax.config.update("jax_enable_x64", True) +except ImportError: + raise ImportError("Jax not installed") try: from tinygp import GaussianProcess, kernels @@ -376,6 +380,11 @@ def get_prior(params_list, prior_dict): -------- A prior function for a Red Noise kernel and a Gaussian mean function Obain the parameters list + >>> if not can_sample: + ... pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.") + >>> if not tfp_available: + ... pytest.skip("Tensorflow probability required to make priors.") + >>> params_list = get_gp_params("RN", "gaussian") Make a prior dictionary using tensorflow_probability distributions diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 992e34045..74af0b2a0 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -1,23 +1,43 @@ import os +import pytest import numpy as np import matplotlib.pyplot as plt -import jax -import jax.numpy as jnp -from jax import random +try: + import jax + import jax.numpy as jnp + from jax import random -jax.config.update("jax_enable_x64", True) + jax.config.update("jax_enable_x64", True) +except ImportError: + pytest.skip(allow_module_level=True) + +_HAS_TINYGP = True +_HAS_TFP = True +_HAS_JAXNS = True + +try: + import tinygp + from tinygp import GaussianProcess, kernels +except ImportError: + _HAS_TINYGP = False -from tinygp import GaussianProcess, kernels from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params from stingray.modeling.gpmodeling import get_prior, get_likelihood, GPResult from stingray import Lightcurve -import tensorflow_probability.substrates.jax as tfp +try: + import tensorflow_probability.substrates.jax as tfp -tfpd = tfp.distributions + tfpd = tfp.distributions +except ImportError: + _HAS_TFP = False -from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model +try: + import jaxns + from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model +except ImportError: + _HAS_JAXNS = False def clear_all_figs(): @@ -26,6 +46,7 @@ def clear_all_figs(): plt.close(fig) +@pytest.mark.skipif(not _HAS_TINYGP, reason="tinygp not installed") class Testget_kernel(object): def setup_class(self): self.x = np.linspace(0, 1, 5) @@ -126,89 +147,93 @@ def test_get_mean_fred(self): ) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0) assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all() - class Testget_gp_params(object): - def setup_class(self): - pass - - def test_get_gp_params_rn(self): - assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] - assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] - assert get_gp_params("RN", "skew_exponential") == [ - "arn", - "crn", - "A", - "t0", - "sig1", - "sig2", - ] - assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] - - def test_get_gp_params_qpo_plus_rn(self): - assert get_gp_params("QPO_plus_RN", "gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "sig", - ] - assert get_gp_params("QPO_plus_RN", "constant") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - ] - assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "sig1", - "sig2", - ] - assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "sig1", - "sig2", - ] - assert get_gp_params("QPO_plus_RN", "exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "sig", - ] - assert get_gp_params("QPO_plus_RN", "fred") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "delta", - "phi", - ] - +class Testget_gp_params(object): + def setup_class(self): + pass + + def test_get_gp_params_rn(self): + assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] + assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] + assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] + assert get_gp_params("RN", "skew_exponential") == [ + "arn", + "crn", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] + assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] + + def test_get_gp_params_qpo_plus_rn(self): + assert get_gp_params("QPO_plus_RN", "gaussian") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig", + ] + assert get_gp_params("QPO_plus_RN", "constant") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + ] + assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("QPO_plus_RN", "exponential") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig", + ] + assert get_gp_params("QPO_plus_RN", "fred") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "delta", + "phi", + ] + + +@pytest.mark.skipif( + not (_HAS_TINYGP and _HAS_TFP and _HAS_JAXNS), reason="tinygp, tfp or jaxns not installed" +) class TestGPResult(object): def setup_class(self): self.Times = np.linspace(0, 1, 64) From e9036f5ead3251a436540db7aee4faaed3af2cb2 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Mon, 14 Aug 2023 15:38:39 +0530 Subject: [PATCH 14/50] Jax import changed --- stingray/modeling/gpmodeling.py | 63 +++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 6841cc57b..40977280e 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -1,4 +1,3 @@ -import pytest import numpy as np import matplotlib.pyplot as plt import functools @@ -6,16 +5,13 @@ try: import jax -except ImportError: - pytest.skip(allow_module_level=True) - -try: import jax.numpy as jnp from jax import jit, random jax.config.update("jax_enable_x64", True) + jax_avail = True except ImportError: - raise ImportError("Jax not installed") + jax_avail = False try: from tinygp import GaussianProcess, kernels @@ -60,6 +56,9 @@ def get_kernel(kernel_type, kernel_params): Should contain the parameters for the selected kernel """ + if not jax_avail: + raise ImportError("Jax is required") + if not can_make_gp: raise ImportError("Tinygp is required to make kernels") @@ -95,6 +94,9 @@ def get_mean(mean_type, mean_params): Should contain the parameters for the selected mean """ + if not jax_avail: + raise ImportError("Jax is required") + if mean_type == "gaussian": mean = functools.partial(_gaussian, mean_params=mean_params) elif mean_type == "exponential": @@ -340,7 +342,7 @@ def get_gp_params(kernel_type, mean_type): Examples -------- - >>> get_gp_params("QPO_plus_RN", "gaussian") + get_gp_params("QPO_plus_RN", "gaussian") ['arn', 'crn', 'aqpo', 'cqpo', 'freq', 'A', 't0', 'sig'] """ kernel_params = _get_kernel_params(kernel_type) @@ -380,25 +382,28 @@ def get_prior(params_list, prior_dict): -------- A prior function for a Red Noise kernel and a Gaussian mean function Obain the parameters list - >>> if not can_sample: - ... pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.") - >>> if not tfp_available: - ... pytest.skip("Tensorflow probability required to make priors.") + if not can_sample: + pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.") + if not tfp_available: + pytest.skip("Tensorflow probability required to make priors.") - >>> params_list = get_gp_params("RN", "gaussian") + params_list = get_gp_params("RN", "gaussian") Make a prior dictionary using tensorflow_probability distributions - >>> prior_dict = { - ... "A": tfpd.Uniform(low = 1e-1, high = 2e+2), - ... "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), - ... "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), - ... "arn": tfpd.Uniform(low = 0.1 , high = 2 ), - ... "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), - ... } + prior_dict = { + "A": tfpd.Uniform(low = 1e-1, high = 2e+2), + "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), + "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), + "arn": tfpd.Uniform(low = 0.1 , high = 2 ), + "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), + } - >>> prior_model = get_prior(params_list, prior_dict) + prior_model = get_prior(params_list, prior_dict) """ + if not jax_avail: + raise ImportError("Jax is required") + if not can_sample: raise ImportError("Jaxns not installed. Cannot make jaxns specific prior.") @@ -454,6 +459,9 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): The jaxns specific log likelihood function. """ + if not jax_avail: + raise ImportError("Jax is required") + if not can_make_gp: raise ImportError("Tinygp is required to make the GP model.") @@ -516,6 +524,9 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): The results of the nested sampling process """ + if not jax_avail: + raise ImportError("Jax is required") + if not can_sample: raise ImportError("Jaxns not installed! Can't sample!") @@ -600,7 +611,7 @@ def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other acceptable argument for ``matplotlib.pyplot.axis()`` method. - save : bool, optionalm, default ``False`` + save : bool, optional, default ``False`` If ``True``, save the figure with specified filename. filename : str @@ -638,7 +649,7 @@ def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): return plt def weighted_posterior_plot( - self, name: str, n=0, rkey=random.PRNGKey(1234), axis=None, save=False, filename=None + self, name: str, n=0, rkey=None, axis=None, save=False, filename=None ): """ Returns the weighted posterior histogram for the given parameter @@ -673,6 +684,9 @@ def weighted_posterior_plot( plt : ``matplotlib.pyplot`` object Reference to plot, call ``show()`` to display it """ + if rkey is None: + rkey = random.PRNGKey(1234) + nsamples = self.Results.total_num_samples log_p = self.Results.log_dp_mean samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] @@ -720,7 +734,7 @@ def corner_plot( param2: str, n1=0, n2=0, - rkey=random.PRNGKey(1234), + rkey=None, axis=None, save=False, filename=None, @@ -767,6 +781,9 @@ def corner_plot( plt : ``matplotlib.pyplot`` object Reference to plot, call ``show()`` to display it """ + if rkey is None: + rkey = random.PRNGKey(1234) + nsamples = self.Results.total_num_samples log_p = self.Results.log_dp_mean samples1 = self.Results.samples[param1].reshape((nsamples, -1))[:, n1] From f89acdd674e52d5cc55975f3dc5629b5498629b1 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Thu, 17 Aug 2023 14:01:42 +0530 Subject: [PATCH 15/50] added kernel and warnings --- stingray/modeling/gpmodeling.py | 12 ++++++++++++ stingray/modeling/tests/test_gpmodeling.py | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 40977280e..051cf0b5c 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -77,6 +77,16 @@ def get_kernel(kernel_type, kernel_params): scale=1 / kernel_params["crn"], sigma=(kernel_params["arn"]) ** 0.5 ) return kernel + elif kernel_type == "QPO": + kernel = kernels.quasisep.Celerite( + a=kernel_params["aqpo"], + b=0.0, + c=kernel_params["cqpo"], + d=2 * jnp.pi * kernel_params["freq"], + ) + return kernel + else: + raise ValueError("Kernel type not implemented") def get_mean(mean_type, mean_params): @@ -109,6 +119,8 @@ def get_mean(mean_type, mean_params): mean = functools.partial(_skew_exponential, mean_params=mean_params) elif mean_type == "fred": mean = functools.partial(_fred, mean_params=mean_params) + else: + raise ValueError("Mean type not implemented") return mean diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 74af0b2a0..a5a0dbd64 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -74,6 +74,22 @@ def test_get_kernel_rn(self): kernel_rn(self.x, jnp.array([0.0])) == kernel_rn_test(self.x, jnp.array([0.0])) ).all() + def test_get_kernel_qpo(self): + kernel_qpo = kernels.quasisep.Celerite( + a=1, + b=0.0, + c=1, + d=2 * jnp.pi * 1, + ) + kernel_qpo_test = get_kernel("QPO", self.kernel_params) + assert ( + kernel_qpo(self.x, jnp.array([0.0])) == kernel_qpo_test(self.x, jnp.array([0.0])) + ).all() + + def test_value_error(self): + with pytest.raises(ValueError, match="Kernel type not implemented"): + get_kernel("periodic", self.kernel_params) + class Testget_mean(object): def setup_class(self): @@ -147,6 +163,10 @@ def test_get_mean_fred(self): ) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0) assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all() + def test_value_error(self): + with pytest.raises(ValueError, match="Mean type not implemented"): + get_mean("polynomial", self.mean_params) + class Testget_gp_params(object): def setup_class(self): From 4f631e3bbc3d5c0f0116af4bc47f5b32abaad192 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Sun, 20 Aug 2023 14:11:44 +0530 Subject: [PATCH 16/50] Adding max samples option --- stingray/modeling/gpmodeling.py | 7 +++++-- stingray/modeling/tests/test_gpmodeling.py | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 051cf0b5c..f85f13f3f 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -516,7 +516,7 @@ def __init__(self, Lc: Lightcurve) -> None: self.counts = Lc.counts self.Result = None - def sample(self, prior_model=None, likelihood_model=None, **kwargs): + def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): """ Makes a Jaxns nested sampler over the Gaussian Process, given the prior and likelihood model @@ -530,6 +530,9 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): A likelihood fucntion which takes in the arguments of the prior model and returns the loglikelihood of the model + max_samples: int, default 1e4 + The maximum number of samples to be taken by the nested sampler + Returns ---------- Results: jaxns.results.NestedSamplerResults object @@ -548,7 +551,7 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) - self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) + self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) Termination_reason, State = self.Exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index a5a0dbd64..e65dff3f8 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -297,14 +297,16 @@ def setup_class(self): NSmodel = Model(prior_model=prior_model, log_likelihood=likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) - Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) + Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=5e3) Termination_reason, State = Exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) self.Results = Exact_ns.to_results(State, Termination_reason) self.gpresult = GPResult(lc) - self.gpresult.sample(prior_model=prior_model, likelihood_model=likelihood_model) + self.gpresult.sample( + prior_model=prior_model, likelihood_model=likelihood_model, max_samples=5e3 + ) def test_sample(self): for key in self.params_list: From 6c8ae22a6ddc8a5c13f10bff0b76787620aa1f43 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Sun, 20 Aug 2023 17:54:38 +0530 Subject: [PATCH 17/50] Added log parameters --- stingray/modeling/gpmodeling.py | 39 ++++-- stingray/modeling/tests/test_gpmodeling.py | 144 +++++++++------------ 2 files changed, 90 insertions(+), 93 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index f85f13f3f..134ddf145 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -301,15 +301,20 @@ def _get_kernel_params(kernel_type): ---------- kernel_type: string The type of kernel to be used for the Gaussian Process model + The parameters in log scale have a prefix of "log_" Returns ------- A list of the parameters for the kernel for the GP model """ if kernel_type == "RN": - return ["arn", "crn"] + return ["log_arn", "log_crn"] elif kernel_type == "QPO_plus_RN": - return ["arn", "crn", "aqpo", "cqpo", "freq"] + return ["log_arn", "log_crn", "log_aqpo", "log_cqpo", "log_freq"] + elif kernel_type == "QPO": + return ["log_aqpo", "log_cqpo", "log_freq"] + else: + raise ValueError("Kernel type not implemented") def _get_mean_params(mean_type): @@ -320,19 +325,22 @@ def _get_mean_params(mean_type): ---------- mean_type: string The type of mean to be used for the Gaussian Process model + The parameters in log scale have a prefix of "log_" Returns ------- A list of the parameters for the mean for the GP model """ if (mean_type == "gaussian") or (mean_type == "exponential"): - return ["A", "t0", "sig"] + return ["log_A", "t0", "log_sig"] elif mean_type == "constant": - return ["A"] + return ["log_A"] elif (mean_type == "skew_gaussian") or (mean_type == "skew_exponential"): - return ["A", "t0", "sig1", "sig2"] + return ["log_A", "t0", "log_sig1", "log_sig2"] elif mean_type == "fred": - return ["A", "t0", "delta", "phi"] + return ["log_A", "t0", "delta", "phi"] + else: + raise ValueError("Mean type not implemented") def get_gp_params(kernel_type, mean_type): @@ -355,7 +363,7 @@ def get_gp_params(kernel_type, mean_type): Examples -------- get_gp_params("QPO_plus_RN", "gaussian") - ['arn', 'crn', 'aqpo', 'cqpo', 'freq', 'A', 't0', 'sig'] + ['log_arn', 'log_crn', 'log_aqpo', 'log_cqpo', 'log_freq', 'log_A', 't0', 'log_sig'] """ kernel_params = _get_kernel_params(kernel_type) mean_params = _get_mean_params(mean_type) @@ -381,6 +389,7 @@ def get_prior(params_list, prior_dict): or special priors from jaxns. **Note**: If jaxns priors are used, then the name given to them should be the same as the corresponding name in the params_list. + Also, if a parameter is to be used in the log scale, it should have a prefix of "log_" Returns ------- @@ -403,11 +412,11 @@ def get_prior(params_list, prior_dict): Make a prior dictionary using tensorflow_probability distributions prior_dict = { - "A": tfpd.Uniform(low = 1e-1, high = 2e+2), + "log_A": tfpd.Uniform(low = jnp.log(1e-1), high = jnp.log(2e+2)), "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), - "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), - "arn": tfpd.Uniform(low = 0.1 , high = 2 ), - "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), + "log_sig": tfpd.Uniform(low = jnp.log(0.5 * 1 / 20), high = jnp.log(2) ), + "log_arn": tfpd.Uniform(low = jnp.log(0.1) , high = jnp.log(2) ), + "log_crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), } prior_model = get_prior(params_list, prior_dict) @@ -441,7 +450,8 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): Makes a jaxns specific log likelihood function which takes in the parameters in the order of the parameters list, and calculates the log likelihood of the data given the parameters, and the model - (kernel, mean) of the GP model. + (kernel, mean) of the GP model. **Note** Any parameters with a prefix + of "log_" are taken to be in the log scale. Parameters ---------- @@ -481,7 +491,10 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): def likelihood_model(*args): dict = {} for i, params in enumerate(params_list): - dict[params] = args[i] + if params[0:4] == "log_": + dict[params[4:]] = jnp.exp(args[i]) + else: + dict[params] = args[i] kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict) mean = get_mean(mean_type=mean_type, mean_params=dict) gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index e65dff3f8..1d7f5eecd 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -173,81 +173,65 @@ def setup_class(self): pass def test_get_gp_params_rn(self): - assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] - assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] - assert get_gp_params("RN", "skew_exponential") == [ - "arn", - "crn", - "A", + assert get_gp_params("RN", "gaussian") == ["log_arn", "log_crn", "log_A", "t0", "log_sig"] + assert get_gp_params("RN", "constant") == ["log_arn", "log_crn", "log_A"] + assert get_gp_params("RN", "skew_gaussian") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "log_sig1", + "log_sig2", ] - assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] - - def test_get_gp_params_qpo_plus_rn(self): - assert get_gp_params("QPO_plus_RN", "gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "skew_exponential") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig", - ] - assert get_gp_params("QPO_plus_RN", "constant") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + "log_sig1", + "log_sig2", ] - assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "exponential") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "log_sig", ] - assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "fred") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "delta", + "phi", ] - assert get_gp_params("QPO_plus_RN", "exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + + def test_get_gp_params_qpo_plus_rn(self): + assert get_gp_params("QPO_plus_RN", "gaussian") == [ + "log_arn", + "log_crn", + "log_aqpo", + "log_cqpo", + "log_freq", + "log_A", "t0", - "sig", + "log_sig", ] - assert get_gp_params("QPO_plus_RN", "fred") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + with pytest.raises(ValueError, match="Mean type not implemented"): + get_gp_params("QPO_plus_RN", "notimplemented") + + with pytest.raises(ValueError, match="Kernel type not implemented"): + get_gp_params("notimplemented", "gaussian") + + def test_get_qpo(self): + assert get_gp_params("QPO", "gaussian") == [ + "log_aqpo", + "log_cqpo", + "log_freq", + "log_A", "t0", - "delta", - "phi", + "log_sig", ] @@ -278,11 +262,11 @@ def setup_class(self): # The prior dictionary, with suitable tfpd prior distributions prior_dict = { - "A": tfpd.Uniform(low=0.1 * span, high=2 * span), + "log_A": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), "t0": tfpd.Uniform(low=self.Times[0] - 0.1 * T, high=self.Times[-1] + 0.1 * T), - "sig": tfpd.Uniform(low=0.5 * 1 / f, high=2 * T), - "arn": tfpd.Uniform(low=0.1 * span, high=2 * span), - "crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)), + "log_sig": tfpd.Uniform(low=jnp.log(0.5 * 1 / f), high=jnp.log(2 * T)), + "log_arn": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), + "log_crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)), } prior_model = get_prior(self.params_list, prior_dict) @@ -339,15 +323,15 @@ def test_max_likelihood_parameters(self): assert key in self.gpresult.get_max_likelihood_parameters() def test_posterior_plot(self): - self.gpresult.posterior_plot("A") + self.gpresult.posterior_plot("log_A") assert plt.fignum_exists(1) def test_posterior_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_Posterior_plot.png" + outfname = "log_A_Posterior_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.posterior_plot("A", save=True) + self.gpresult.posterior_plot("log_A", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -356,20 +340,20 @@ def test_posterior_plot_labels_and_fname(self): outfname = "blabla.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.posterior_plot("A", axis=[0, 14, 0, 0.5], save=True, filename=outfname) + self.gpresult.posterior_plot("log_A", axis=[0, 14, 0, 0.5], save=True, filename=outfname) assert os.path.exists(outfname) os.unlink(outfname) def test_weighted_posterior_plot(self): - self.gpresult.weighted_posterior_plot("A") + self.gpresult.weighted_posterior_plot("log_A") assert plt.fignum_exists(1) def test_weighted_posterior_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_Weighted_Posterior_plot.png" + outfname = "log_A_Weighted_Posterior_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.weighted_posterior_plot("A", save=True) + self.gpresult.weighted_posterior_plot("log_A", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -379,21 +363,21 @@ def test_weighted_posterior_plot_labels_and_fname(self): if os.path.exists(outfname): os.unlink(outfname) self.gpresult.weighted_posterior_plot( - "A", axis=[0, 14, 0, 0.5], save=True, filename=outfname + "log_A", axis=[0, 14, 0, 0.5], save=True, filename=outfname ) assert os.path.exists(outfname) os.unlink(outfname) def test_corner_plot(self): - self.gpresult.corner_plot("A", "t0") + self.gpresult.corner_plot("log_A", "t0") assert plt.fignum_exists(1) def test_corner_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_t0_Corner_plot.png" + outfname = "log_A_t0_Corner_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("A", "t0", save=True) + self.gpresult.corner_plot("log_A", "t0", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -402,6 +386,6 @@ def test_corner_plot_labels_and_fname(self): outfname = "blabla.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) + self.gpresult.corner_plot("log_A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) assert os.path.exists(outfname) os.unlink(outfname) From bef7b1492c8d9b6945d67e0b6b7330c3f4f9e154 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Fri, 25 Aug 2023 21:27:31 +0530 Subject: [PATCH 18/50] Changed Function Names --- stingray/modeling/gpmodeling.py | 14 +++++++------- stingray/modeling/tests/test_gpmodeling.py | 20 +++++++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 134ddf145..b3e16cacb 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -444,7 +444,7 @@ def prior_model(): return prior_model -def get_likelihood(params_list, kernel_type, mean_type, **kwargs): +def get_log_likelihood(params_list, kernel_type, mean_type, **kwargs): """ A log likelihood generator function based on given values. Makes a jaxns specific log likelihood function which takes in the @@ -559,9 +559,9 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): raise ImportError("Jaxns not installed! Can't sample!") self.prior_model = prior_model - self.likelihood_model = likelihood_model + self.log_likelihood_model = likelihood_model - NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood_model) + NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) @@ -756,7 +756,7 @@ def weighted_posterior_plot( plt.savefig(filename) return plt - def corner_plot( + def comparison_plot( self, param1: str, param2: str, @@ -768,7 +768,7 @@ def corner_plot( filename=None, ): """ - Plots the corner plot between two given parameters + Plots the comparison plot between two given parameters Parameters ---------- @@ -834,7 +834,7 @@ def corner_plot( density=True, cmap="GnBu", ) - plt.title("Corner Plot of " + str(param1) + " and " + str(param2)) + plt.title("Comparison Plot of " + str(param1) + " and " + str(param2)) plt.xlabel(param2) plt.ylabel(param1) plt.colorbar() @@ -843,7 +843,7 @@ def corner_plot( if save: if filename is None: - plt.savefig(str(param1) + "_" + str(param2) + "_Corner_plot.png") + plt.savefig(str(param1) + "_" + str(param2) + "_Comparison_plot.png") else: plt.savefig(filename) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 1d7f5eecd..02743af34 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -23,7 +23,7 @@ _HAS_TINYGP = False from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params -from stingray.modeling.gpmodeling import get_prior, get_likelihood, GPResult +from stingray.modeling.gpmodeling import get_prior, get_log_likelihood, GPResult from stingray import Lightcurve try: @@ -270,7 +270,7 @@ def setup_class(self): } prior_model = get_prior(self.params_list, prior_dict) - likelihood_model = get_likelihood( + likelihood_model = get_log_likelihood( self.params_list, kernel_type="RN", mean_type="gaussian", @@ -368,24 +368,26 @@ def test_weighted_posterior_plot_labels_and_fname(self): assert os.path.exists(outfname) os.unlink(outfname) - def test_corner_plot(self): - self.gpresult.corner_plot("log_A", "t0") + def test_comparison_plot(self): + self.gpresult.comparison_plot("log_A", "t0") assert plt.fignum_exists(1) - def test_corner_plot_labels_and_fname_default(self): + def test_comparison_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "log_A_t0_Corner_plot.png" + outfname = "log_A_t0_Comparison_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("log_A", "t0", save=True) + self.gpresult.comparison_plot("log_A", "t0", save=True) assert os.path.exists(outfname) os.unlink(outfname) - def test_corner_plot_labels_and_fname(self): + def test_comparison_plot_labels_and_fname(self): clear_all_figs() outfname = "blabla.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("log_A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) + self.gpresult.comparison_plot( + "log_A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname + ) assert os.path.exists(outfname) os.unlink(outfname) From 8701ae7b280a73c1bf055bc93aed0ba9968b3f86 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Fri, 25 Aug 2023 23:05:03 +0530 Subject: [PATCH 19/50] Docstring changes --- stingray/modeling/gpmodeling.py | 147 +++++++++++---------- stingray/modeling/tests/test_gpmodeling.py | 4 +- 2 files changed, 76 insertions(+), 75 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index b3e16cacb..e52a71777 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -49,7 +49,8 @@ def get_kernel(kernel_type, kernel_params): ---------- kernel_type: string The type of kernel to be used for the Gaussian Process - To be selected from the kernels already implemented + To be selected from the kernels already implemented: + ["RN", "QPO", "QPO_plus_RN"] kernel_params: dict Dictionary containing the parameters for the kernel @@ -97,7 +98,9 @@ def get_mean(mean_type, mean_params): ---------- mean_type: string The type of mean to be used for the Gaussian Process - To be selected from the mean functions already implemented + To be selected from the mean functions already implemented: + ["gaussian", "exponential", "constant", "skew_gaussian", + "skew_exponential", "fred"] mean_params: dict Dictionary containing the parameters for the mean @@ -216,7 +219,7 @@ def _skew_gaussian(t, mean_params): sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] - return jnp.sum( + y = jnp.sum( A * jnp.where( t > t0, @@ -225,6 +228,7 @@ def _skew_gaussian(t, mean_params): ), axis=0, ) + return y def _skew_exponential(t, mean_params): @@ -252,7 +256,7 @@ def _skew_exponential(t, mean_params): sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] - return jnp.sum( + y = jnp.sum( A * jnp.where( t > t0, @@ -261,6 +265,7 @@ def _skew_exponential(t, mean_params): ), axis=0, ) + return y def _fred(t, mean_params): @@ -351,10 +356,13 @@ def get_gp_params(kernel_type, mean_type): Parameters ---------- kernel_type: string - The type of kernel to be used for the Gaussian Process model + The type of kernel to be used for the Gaussian Process model: + ["RN", "QPO", "QPO_plus_RN"] mean_type: string - The type of mean to be used for the Gaussian Process model + The type of mean to be used for the Gaussian Process model: + ["gaussian", "exponential", "constant", "skew_gaussian", + "skew_exponential", "fred"] Returns ------- @@ -402,15 +410,11 @@ def get_prior(params_list, prior_dict): Examples -------- A prior function for a Red Noise kernel and a Gaussian mean function - Obain the parameters list - if not can_sample: - pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.") - if not tfp_available: - pytest.skip("Tensorflow probability required to make priors.") + # Obtain the parameters list params_list = get_gp_params("RN", "gaussian") - Make a prior dictionary using tensorflow_probability distributions + # Make a prior dictionary using tensorflow_probability distributions prior_dict = { "log_A": tfpd.Uniform(low = jnp.log(1e-1), high = jnp.log(2e+2)), "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), @@ -444,7 +448,7 @@ def prior_model(): return prior_model -def get_log_likelihood(params_list, kernel_type, mean_type, **kwargs): +def get_log_likelihood(params_list, kernel_type, mean_type, times, counts, **kwargs): """ A log likelihood generator function based on given values. Makes a jaxns specific log likelihood function which takes in the @@ -462,23 +466,23 @@ def get_log_likelihood(params_list, kernel_type, mean_type, **kwargs): A dictionary of the priors of parameters to be used. kernel_type: - The type of kernel to be used in the model. + The type of kernel to be used in the model: + ["RN", "QPO", "QPO_plus_RN"] mean_type: - The type of mean to be used in the model. + The type of mean to be used in the model: + ["gaussian", "exponential", "constant", "skew_gaussian", + "skew_exponential", "fred"] + + times: np.array or jnp.array + The time array of the lightcurve - **kwargs: - The keyword arguments to be used in the log likelihood function. - **Note**: The keyword arguments Times and counts are necessary for - calculating the log likelihood. - Times: np.array or jnp.array - The time array of the lightcurve - counts: np.array or jnp.array - The photon counts array of the lightcurve + counts: np.array or jnp.array + The photon counts array of the lightcurve Returns ------- - The jaxns specific log likelihood function. + The Jaxns specific log likelihood function. """ if not jax_avail: @@ -497,8 +501,8 @@ def likelihood_model(*args): dict[params] = args[i] kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict) mean = get_mean(mean_type=mean_type, mean_params=dict) - gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) - return gp.log_probability(kwargs["counts"]) + gp = GaussianProcess(kernel, times, mean_value=mean(times)) + return gp.log_probability(counts) return likelihood_model @@ -513,21 +517,13 @@ class GPResult: lc: Stingray.Lightcurve object The lightcurve on which the bayesian inference is to be done - Other Parameters - ---------------- - time : class: np.array - The array containing the times of the lightcurve - - counts : class: np.array - The array containing the photon counts of the lightcurve - """ - def __init__(self, Lc: Lightcurve) -> None: - self.lc = Lc - self.time = Lc.time - self.counts = Lc.counts - self.Result = None + def __init__(self, lc: Lightcurve) -> None: + self.lc = lc + self.time = lc.time + self.counts = lc.counts + self.result = None def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): """ @@ -537,18 +533,23 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): Parameters ---------- prior_model: jaxns.prior.PriorModelType object - A prior generator object + A prior generator object. + Can be made using the get_prior function or can use your own jaxns + compatible prior function. likelihood_model: jaxns.types.LikelihoodType object A likelihood fucntion which takes in the arguments of the prior - model and returns the loglikelihood of the model + model and returns the loglikelihood of the model. + Can be made using the get_log_likelihood function or can use your own + log_likelihood function with same order of arguments as the prior_model. + max_samples: int, default 1e4 The maximum number of samples to be taken by the nested sampler Returns ---------- - Results: jaxns.results.NestedSamplerResults object + results: jaxns.results.NestedSamplerResults object The results of the nested sampling process """ @@ -564,49 +565,49 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) - self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) - Termination_reason, State = self.Exact_ns( + self.exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) + termination_reason, State = self.exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) - self.Results = self.Exact_ns.to_results(State, Termination_reason) + self.results = self.exact_ns.to_results(State, termination_reason) print("Simulation Complete") def get_evidence(self): """ Returns the log evidence of the model """ - return self.Results.log_Z_mean + return self.results.log_Z_mean def print_summary(self): """ Prints a summary table for the model parameters """ - self.Exact_ns.summary(self.Results) + self.exact_ns.summary(self.results) def plot_diagnostics(self): """ Plots the diagnostic plots for the sampling process """ - self.Exact_ns.plot_diagnostics(self.Results) + self.exact_ns.plot_diagnostics(self.results) def plot_cornerplot(self): """ Plots the corner plot for the sampled hyperparameters """ - self.Exact_ns.plot_cornerplot(self.Results) + self.exact_ns.plot_cornerplot(self.results) def get_parameters_names(self): """ Returns the names of the parameters """ - return sorted(self.Results.samples.keys()) + return sorted(self.results.samples.keys()) def get_max_posterior_parameters(self): """ Returns the optimal parameters for the model based on the NUTS sampling """ - max_post_idx = jnp.argmax(self.Results.log_posterior_density) - map_points = jax.tree_map(lambda x: x[max_post_idx], self.Results.samples) + max_post_idx = jnp.argmax(self.results.log_posterior_density) + map_points = jax.tree_map(lambda x: x[max_post_idx], self.results.samples) return map_points @@ -614,8 +615,8 @@ def get_max_likelihood_parameters(self): """ Retruns the maximum likelihood parameters """ - max_like_idx = jnp.argmax(self.Results.log_L_samples) - max_like_points = jax.tree_map(lambda x: x[max_like_idx], self.Results.samples) + max_like_idx = jnp.argmax(self.results.log_L_samples) + max_like_points = jax.tree_map(lambda x: x[max_like_idx], self.results.samples) return max_like_points @@ -631,7 +632,7 @@ def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): used in the prior_function n : int, default 0 - The index of the parameter to be plotted. + The index of the parameter to be plotted (for multi component parameters). For multivariate parameters, the index of the specific parameter to be plotted. axis : list, tuple, string, default ``None`` @@ -651,13 +652,13 @@ def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): Reference to plot, call ``show()`` to display it """ - nsamples = self.Results.total_num_samples - samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] + nsamples = self.results.total_num_samples + samples = self.results.samples[name].reshape((nsamples, -1))[:, n] plt.hist( samples, bins="auto", density=True, alpha=1.0, label=name, fc="None", edgecolor="black" ) - mean1 = jnp.mean(self.Results.samples[name]) - std1 = jnp.std(self.Results.samples[name]) + mean1 = jnp.mean(self.results.samples[name]) + std1 = jnp.std(self.results.samples[name]) plt.axvline(mean1, color="red", linestyle="dashed", label="mean") plt.axvline(mean1 + std1, color="green", linestyle="dotted") plt.axvline(mean1 - std1, linestyle="dotted", color="green") @@ -690,7 +691,7 @@ def weighted_posterior_plot( used in the prior_function n : int, default 0 - The index of the parameter to be plotted. + The index of the parameter to be plotted (for multi component parameters). For multivariate parameters, the index of the specific parameter to be plotted. key: jax.random.PRNGKey, default ``random.PRNGKey(1234)`` @@ -715,17 +716,17 @@ def weighted_posterior_plot( if rkey is None: rkey = random.PRNGKey(1234) - nsamples = self.Results.total_num_samples - log_p = self.Results.log_dp_mean - samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] + nsamples = self.results.total_num_samples + log_p = self.results.log_dp_mean + samples = self.results.samples[name].reshape((nsamples, -1))[:, n] weights = jnp.where(jnp.isfinite(samples), jnp.exp(log_p), 0.0) log_weights = jnp.where(jnp.isfinite(samples), log_p, -jnp.inf) samples_resampled = resample( - rkey, samples, log_weights, S=max(10, int(self.Results.ESS)), replace=True + rkey, samples, log_weights, S=max(10, int(self.results.ESS)), replace=True ) - nbins = max(10, int(jnp.sqrt(self.Results.ESS)) + 1) + nbins = max(10, int(jnp.sqrt(self.results.ESS)) + 1) binsx = jnp.linspace(*jnp.percentile(samples_resampled, jnp.asarray([0, 100])), 2 * nbins) plt.hist( @@ -783,11 +784,11 @@ def comparison_plot( used in the prior_function n1 : int, default 0 - The index of the first parameter to be plotted. + The index of the first parameter to be plotted (for multi component parameters). For multivariate parameters, the index of the specific parameter to be plotted. n2 : int, default 0 - The index of the second parameter to be plotted. + The index of the second parameter to be plotted (for multi component parameters). For multivariate parameters, the index of the specific parameter to be plotted. key: jax.random.PRNGKey, default ``random.PRNGKey(1234)`` @@ -812,19 +813,19 @@ def comparison_plot( if rkey is None: rkey = random.PRNGKey(1234) - nsamples = self.Results.total_num_samples - log_p = self.Results.log_dp_mean - samples1 = self.Results.samples[param1].reshape((nsamples, -1))[:, n1] - samples2 = self.Results.samples[param2].reshape((nsamples, -1))[:, n2] + nsamples = self.results.total_num_samples + log_p = self.results.log_dp_mean + samples1 = self.results.samples[param1].reshape((nsamples, -1))[:, n1] + samples2 = self.results.samples[param2].reshape((nsamples, -1))[:, n2] log_weights = jnp.where(jnp.isfinite(samples2), log_p, -jnp.inf) - nbins = max(10, int(jnp.sqrt(self.Results.ESS)) + 1) + nbins = max(10, int(jnp.sqrt(self.results.ESS)) + 1) samples_resampled = resample( rkey, jnp.stack([samples1, samples2], axis=-1), log_weights, - S=max(10, int(self.Results.ESS)), + S=max(10, int(self.results.ESS)), replace=True, ) plt.hist2d( diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 02743af34..722467ce3 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -274,7 +274,7 @@ def setup_class(self): self.params_list, kernel_type="RN", mean_type="gaussian", - Times=self.Times, + times=self.Times, counts=self.counts, ) @@ -294,7 +294,7 @@ def setup_class(self): def test_sample(self): for key in self.params_list: - assert (self.Results.samples[key]).all() == (self.gpresult.Results.samples[key]).all() + assert (self.Results.samples[key]).all() == (self.gpresult.results.samples[key]).all() def test_get_evidence(self): assert self.Results.log_Z_mean == self.gpresult.get_evidence() From 51a9cef3b251ad785549600ac52409ba7092115e Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Mon, 28 Aug 2023 21:41:12 +0530 Subject: [PATCH 20/50] Changelog Changed --- docs/changes/739.feature.rst | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/docs/changes/739.feature.rst b/docs/changes/739.feature.rst index b1661b37c..ff7cecff8 100644 --- a/docs/changes/739.feature.rst +++ b/docs/changes/739.feature.rst @@ -1 +1,19 @@ -A feature dealing with Gaussian Processes for Qpo analysis \ No newline at end of file +This is a JAX implementation of the GP tool by `Hubener et al `_ +for QPO detection and parameter analysis. + +This feature makes use of tinygp library for Gaussian Processes implementation, and jaxns for nested sampling, +and is kept in the stingray.modeling.gpmodelling module. + +Main features of the module are: + +- get_prior: This function makes the prior function for a specified prior dictionary. +- get_likelihood: This function makes the log_likelihood function for the given Kernel, Mean model and lightcurve. +- GPResult class: The class which takes a Lightcurve, and performs Nested Sampling for a given prior and likelihood. + +The additional Dependencies for the code +- jax +- tinygp +- jaxns +- etils +- tensorflow_probability +- typing_extensions \ No newline at end of file From f9d38ca8ee37059da5871ae64f88713c512835bc Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 29 Aug 2023 01:26:57 +0530 Subject: [PATCH 21/50] Improved get_mean docs --- stingray/modeling/gpmodeling.py | 185 ++++++++++++++++++++------------ 1 file changed, 115 insertions(+), 70 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index e52a71777..8780e6af6 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -92,7 +92,7 @@ def get_kernel(kernel_type, kernel_params): def get_mean(mean_type, mean_params): """ - Function for producing the mean for the Gaussian Process. + Function for producing the mean function for the Gaussian Process. Parameters ---------- @@ -106,118 +106,153 @@ def get_mean(mean_type, mean_params): Dictionary containing the parameters for the mean Should contain the parameters for the selected mean + Returns + ------- + A function which takes in the time coordinates and returns the mean values. + + Examples + -------- + Unimodal Gaussian Mean Function: + mean_params = {"A": 3.0, "t0": 0.2, "sig1": 0.1, "sig2": 0.4} + mean = get_mean("gaussian", mean_params) + + Multimodal Gaussian Mean Function: + mean_params = {"A": jnp.array([3.0, 4.0]), "t0": jnp.array([0.2, 1]), + "sig1": jnp.array([0.1, 0.4]), "sig2": jnp.array([0.4, 0.1])} + mean = get_mean("gaussian", mean_params) + """ if not jax_avail: raise ImportError("Jax is required") if mean_type == "gaussian": - mean = functools.partial(_gaussian, mean_params=mean_params) + mean = functools.partial(_gaussian, params=mean_params) elif mean_type == "exponential": - mean = functools.partial(_exponential, mean_params=mean_params) + mean = functools.partial(_exponential, params=mean_params) elif mean_type == "constant": - mean = functools.partial(_constant, mean_params=mean_params) + mean = functools.partial(_constant, params=mean_params) elif mean_type == "skew_gaussian": - mean = functools.partial(_skew_gaussian, mean_params=mean_params) + mean = functools.partial(_skew_gaussian, params=mean_params) elif mean_type == "skew_exponential": - mean = functools.partial(_skew_exponential, mean_params=mean_params) + mean = functools.partial(_skew_exponential, params=mean_params) elif mean_type == "fred": - mean = functools.partial(_fred, mean_params=mean_params) + mean = functools.partial(_fred, params=mean_params) else: raise ValueError("Mean type not implemented") return mean -def _gaussian(t, mean_params): +def _gaussian(t, params): """A gaussian flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the gaussian. + + params: dict + The dictionary contating parameter values of the gaussian flare. + + The parameters for the gaussian flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the gaussian. Returns ------- The y values for the gaussian flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(params["sig"])[:, jnp.newaxis] return jnp.sum(A * jnp.exp(-((t - t0) ** 2) / (2 * (sig**2))), axis=0) -def _exponential(t, mean_params): +def _exponential(t, params): """An exponential flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the exponential. + + params: dict + The dictionary contating parameter values of the exponential flare. + + The parameters for the exponential flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the exponential. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(params["sig"])[:, jnp.newaxis] return jnp.sum(A * jnp.exp(-jnp.abs(t - t0) / (2 * (sig**2))), axis=0) -def _constant(t, mean_params): +def _constant(t, params): """A constant mean shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Constant amplitude of the flare. + + params: dict + The dictionary contating parameter values of the constant flare. + + The parameters for the constant flare are: + A: jnp.float + Constant amplitude of the flare. Returns ------- The constant value. """ - return mean_params["A"] * jnp.ones_like(t) + return params["A"] * jnp.ones_like(t) -def _skew_gaussian(t, mean_params): +def _skew_gaussian(t, params): """A skew gaussian flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the rising edge. - sig2: - The width parameter for the falling edge. + + params: dict + The dictionary contating parameter values of the skew gaussian flare. + + The parameters for the skew gaussian flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the rising edge. + sig2: jnp.float / jnp.ndarray + The width parameter for the falling edge. Returns ------- The y values for skew gaussian flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] - sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(params["sig2"])[:, jnp.newaxis] y = jnp.sum( A @@ -231,30 +266,35 @@ def _skew_gaussian(t, mean_params): return y -def _skew_exponential(t, mean_params): +def _skew_exponential(t, params): """A skew exponential flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the rising edge. - sig2: - The width parameter for the falling edge. + + params: dict + The dictionary contating parameter values of the skew exponential flare. + + The parameters for the skew exponential flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the rising edge. + sig2: jnp.float / jnp.ndarray + The width parameter for the falling edge. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] - sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(params["sig2"])[:, jnp.newaxis] y = jnp.sum( A @@ -268,30 +308,35 @@ def _skew_exponential(t, mean_params): return y -def _fred(t, mean_params): +def _fred(t, params): """A fast rise exponential decay (FRED) flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - phi: - Symmetry parameter of the flare. - delta: - Offset parameter of the flare. + + params: dict + The dictionary contating parameter values of the FRED flare. + + The parameters for the FRED flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + phi: jnp.float / jnp.ndarray + Symmetry parameter of the flare. + delta: jnp.float / jnp.ndarray + Offset parameter of the flare. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - phi = jnp.atleast_1d(mean_params["phi"])[:, jnp.newaxis] - delta = jnp.atleast_1d(mean_params["delta"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + phi = jnp.atleast_1d(params["phi"])[:, jnp.newaxis] + delta = jnp.atleast_1d(params["delta"])[:, jnp.newaxis] return jnp.sum( A * jnp.exp(-phi * ((t + delta) / t0 + t0 / (t + delta))) * jnp.exp(2 * phi), axis=0 From 3a46a45e206ea85e3f89187b291455400aa2d997 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 30 Aug 2023 16:27:00 +0530 Subject: [PATCH 22/50] Docstrings Updated --- stingray/modeling/gpmodeling.py | 20 ++++++++++++-------- stingray/modeling/tests/test_gpmodeling.py | 4 +++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 8780e6af6..2f7e799b0 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -113,13 +113,13 @@ def get_mean(mean_type, mean_params): Examples -------- Unimodal Gaussian Mean Function: - mean_params = {"A": 3.0, "t0": 0.2, "sig1": 0.1, "sig2": 0.4} + mean_params = {"A": 3.0, "t0": 0.2, "sig": 0.1} mean = get_mean("gaussian", mean_params) - Multimodal Gaussian Mean Function: + Multimodal Skew Gaussian Mean Function: mean_params = {"A": jnp.array([3.0, 4.0]), "t0": jnp.array([0.2, 1]), "sig1": jnp.array([0.1, 0.4]), "sig2": jnp.array([0.4, 0.1])} - mean = get_mean("gaussian", mean_params) + mean = get_mean("skew_gaussian", mean_params) """ if not jax_avail: @@ -338,10 +338,12 @@ def _fred(t, params): phi = jnp.atleast_1d(params["phi"])[:, jnp.newaxis] delta = jnp.atleast_1d(params["delta"])[:, jnp.newaxis] - return jnp.sum( + y = jnp.sum( A * jnp.exp(-phi * ((t + delta) / t0 + t0 / (t + delta))) * jnp.exp(2 * phi), axis=0 ) + return y + def _get_kernel_params(kernel_type): """ @@ -485,8 +487,10 @@ def prior_model(): for i in params_list: if isinstance(prior_dict[i], tfpd.Distribution): parameter = yield Prior(prior_dict[i], name=i) - else: + elif isinstance(prior_dict[i], Prior): parameter = yield prior_dict[i] + else: + raise ValueError("Prior should be a tfpd distribution or a jaxns prior.") prior_list.append(parameter) return tuple(prior_list) @@ -607,10 +611,10 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): self.prior_model = prior_model self.log_likelihood_model = likelihood_model - NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) - NSmodel.sanity_check(random.PRNGKey(10), S=100) + nsmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) + nsmodel.sanity_check(random.PRNGKey(10), S=100) - self.exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) + self.exact_ns = ExactNestedSampler(nsmodel, num_live_points=500, max_samples=max_samples) termination_reason, State = self.exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 722467ce3..dca1e48ca 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -262,7 +262,9 @@ def setup_class(self): # The prior dictionary, with suitable tfpd prior distributions prior_dict = { - "log_A": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), + "log_A": Prior( + tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), name="log_A" + ), "t0": tfpd.Uniform(low=self.Times[0] - 0.1 * T, high=self.Times[-1] + 0.1 * T), "log_sig": tfpd.Uniform(low=jnp.log(0.5 * 1 / f), high=jnp.log(2 * T)), "log_arn": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), From d86b27403a5de3566118363b7b96ef9c529d2dc0 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 5 Sep 2023 19:28:55 +0530 Subject: [PATCH 23/50] Small doc change --- stingray/modeling/gpmodeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 2f7e799b0..5c953604a 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -615,10 +615,10 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): nsmodel.sanity_check(random.PRNGKey(10), S=100) self.exact_ns = ExactNestedSampler(nsmodel, num_live_points=500, max_samples=max_samples) - termination_reason, State = self.exact_ns( + termination_reason, state = self.exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) - self.results = self.exact_ns.to_results(State, termination_reason) + self.results = self.exact_ns.to_results(state, termination_reason) print("Simulation Complete") def get_evidence(self): From fa329cbcfc7a1577d591fac4d2d57b6d2e50e975 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Mon, 12 Jun 2023 18:04:07 +0530 Subject: [PATCH 24/50] Adding_gp_feature --- docs/changes/739.feature.rst | 1 + setup.cfg | 6 + stingray/modeling/gpmodeling.py | 426 ++++++++++++++++++++++++++++++++ 3 files changed, 433 insertions(+) create mode 100644 docs/changes/739.feature.rst create mode 100644 stingray/modeling/gpmodeling.py diff --git a/docs/changes/739.feature.rst b/docs/changes/739.feature.rst new file mode 100644 index 000000000..b1661b37c --- /dev/null +++ b/docs/changes/739.feature.rst @@ -0,0 +1 @@ +A feature dealing with Gaussian Processes for Qpo analysis \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 5d19b487d..e15312b2e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,6 +22,12 @@ install_requires = scipy>=1.1.0 ; Matplotlib 3.4.0 is incompatible with Astropy matplotlib>=3.0,!=3.4.0 + jax + tinygp + jaxns + etils + tensorflow_probability + typing_extensions [options.entry_points] console_scripts = diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py new file mode 100644 index 000000000..1c2917a28 --- /dev/null +++ b/stingray/modeling/gpmodeling.py @@ -0,0 +1,426 @@ +import numpy as np +import matplotlib.pyplot as plt +import jax +import jax.numpy as jnp +import functools +import tensorflow_probability.substrates.jax as tfp + +from jax import jit, random + +from tinygp import GaussianProcess, kernels +from stingray import Lightcurve + +from jaxns import ExactNestedSampler +from jaxns import TerminationCondition + +# from jaxns import analytic_log_evidence +from jaxns import Prior, Model + +jax.config.update("jax_enable_x64", True) + +tfpd = tfp.distributions +tfpb = tfp.bijectors + +__all__ = ["GP", "GPResult"] + + +def get_kernel(kernel_type, kernel_params): + """ + Function for producing the kernel for the Gaussian Process. + Returns the selected Tinygp kernel + + Parameters + ---------- + kernel_type: string + The type of kernel to be used for the Gaussian Process + To be selected from the kernels already implemented + + kernel_params: dict + Dictionary containing the parameters for the kernel + Should contain the parameters for the selected kernel + + """ + if kernel_type == "QPO_plus_RN": + kernel = kernels.quasisep.Exp( + scale=1 / kernel_params["crn"], sigma=(kernel_params["arn"]) ** 0.5 + ) + kernels.quasisep.Celerite( + a=kernel_params["aqpo"], + b=0.0, + c=kernel_params["cqpo"], + d=2 * jnp.pi * kernel_params["freq"], + ) + return kernel + elif kernel_type == "RN": + kernel = kernels.quasisep.Exp( + scale=1 / kernel_params["crn"], sigma=(kernel_params["arn"]) ** 0.5 + ) + return kernel + + +def get_mean(mean_type, mean_params): + """ + Function for producing the mean for the Gaussian Process. + + Parameters + ---------- + mean_type: string + The type of mean to be used for the Gaussian Process + To be selected from the mean functions already implemented + + mean_params: dict + Dictionary containing the parameters for the mean + Should contain the parameters for the selected mean + + """ + if mean_type == "gaussian": + mean = functools.partial(_gaussian, mean_params=mean_params) + elif mean_type == "exponential": + mean = functools.partial(_exponential, mean_params=mean_params) + elif mean_type == "constant": + mean = functools.partial(_constant, mean_params=mean_params) + return mean + + +def _gaussian(t, mean_params): + return mean_params["A"] * jnp.exp( + -((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig"] ** 2)) + ) + + +def _exponential(t, mean_params): + return mean_params["A"] * jnp.exp(-jnp.abs((t - mean_params["t0"])) / mean_params["sig"]) + + +def _constant(t, mean_params): + return mean_params["A"] * jnp.ones_like(t) + + +class GP: + """ + Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian + Process on the lightcurve data, for the given kernel. + + Parameters + ---------- + lc: Stingray.Lightcurve object + The lightcurve on which the gaussian process, is to be fitted + + Model_type: string tuple + Has two strings with the first being the name of the kernel type + and the secound being the mean type + + Model_parameter: dict, default = None + Dictionary conatining the parameters for the mean and kernel + The keys should be accourding to the selected kernel and mean + coressponding to the Model_type + By default, it takes a value None, and the kernel and mean are + then bulit using the pre-set parameters. + + Other Parameters + ---------------- + kernel: class: `TinyGp.kernel` object + The tinygp kernel for the GP + + mean: class: `TinyGp.mean` object + The tinygp mean for the GP + + maingp: class: `TinyGp.GaussianProcess` object + The tinygp gaussian process made on the lightcurve + + """ + + def __init__(self, Lc: Lightcurve, Model_type: tuple, Model_params: dict = None) -> None: + self.lc = Lc + self.Model_type = Model_type + self.Model_param = Model_params + self.kernel = get_kernel(self.Model_type[0], self.Model_param) + self.mean = get_mean(self.Model_type[1], self.Model_param) + self.maingp = GaussianProcess( + self.kernel, Lc.time, mean=self.mean, diag=Model_params["diag"] + ) + + def get_logprob(self): + """ + Returns the logprobability of the lightcurves counts for the + given kernel for the Gaussian Process + """ + cond = self.maingp.condition(self.lc.counts) + return cond.log_probability + + def get_model(self): + """ + Returns the model of the Gaussian Process + """ + return (self.Model_type, self.Model_param) + + def plot_kernel(self): + """ + Plots the kernel of the Gaussian Process + """ + X = self.lc.time + Y = self.kernel(X, np.array([0.0])) + plt.plot(X, Y) + plt.xlabel("distance") + plt.ylabel("Value") + plt.title("Kernel Function") + + def plot_originalgp(self, sample_no=1, seed=0): + """ + Plots samples obtained from the gaussian process for the kernel + + Parameters + ---------- + sample_no: int , default = 1 + Number of GP samples to be taken + + """ + X_test = self.lc.time + _, ax = plt.subplots(1, 1, figsize=(10, 3)) + y_samp = self.maingp.sample(jax.random.PRNGKey(seed), shape=(sample_no,)) + ax.plot(X_test, y_samp[0], "C0", lw=0.5, alpha=0.5, label="samples") + ax.plot(X_test, y_samp[1:].T, "C0", lw=0.5, alpha=0.5) + ax.set_xlabel("time") + ax.set_ylabel("counts") + ax.legend(loc="best") + + def plot_gp(self, sample_no=1, seed=0): + """ + Plots gaussian process, conditioned on the lightcurve + Also, plots the lightcurve along with it + + Parameters + ---------- + sample_no: int , default = 1 + Number of GP samples to be taken + + """ + X_test = self.lc.time + + _, ax = plt.subplots(1, 1, figsize=(10, 3)) + _, cond_gp = self.maingp.condition(self.lc.counts, X_test) + mu = cond_gp.mean + # std = np.sqrt(cond_gp.variance) + + ax.plot(self.lc.time, self.lc.counts, lw=2, color="blue", label="Lightcurve") + ax.plot(X_test, mu, "C1", label="Gaussian Process") + y_samp = cond_gp.sample(jax.random.PRNGKey(seed), shape=(sample_no,)) + ax.plot(X_test, y_samp[0], "C0", lw=0.5, alpha=0.5) + ax.set_xlabel("time") + ax.set_ylabel("counts") + ax.legend(loc="best") + + +def get_prior(kernel_type, mean_type, **kwargs): + """ + A prior generator function based on given values + + Parameters + ---------- + kwargs: + All possible keyword arguments to construct the prior. + + Returns + ------- + The Prior function. + The arguments of the prior function are in the order of + Kernel arguments (RN arguments, QPO arguments), + Mean arguments + Non Windowed arguments + + """ + kwargs["T"] = kwargs["Times"][-1] - kwargs["Times"][0] # Total time + kwargs["f"] = 1 / (kwargs["Times"][1] - kwargs["Times"][0]) # Sampling frequency + kwargs["min"] = jnp.min(kwargs["counts"]) + kwargs["max"] = jnp.max(kwargs["counts"]) + kwargs["span"] = kwargs["max"] - kwargs["min"] + + def RNprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + sig = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig") + return arn, crn, A, t0, sig + + if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return RNprior_model + + def QPOprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") + cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") + freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + sig = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig") + + return arn, crn, aqpo, cqpo, freq, A, t0, sig + + if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return QPOprior_model + + +def get_likelihood(kernel_type, mean_type, **kwargs): + """ + A likelihood generator function based on given values + """ + + @jit + def RNlog_likelihood(arn, crn, A, t0, sig): + rnlikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": 0.0, + "cqpo": 0.0, + "freq": 0.0, + } + + mean_params = { + "A": A, + "t0": t0, + "sig": sig, + } + + kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) + + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return RNlog_likelihood + + @jit + def QPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig): + qpolikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": aqpo, + "cqpo": cqpo, + "freq": freq, + } + + mean_params = { + "A": A, + "t0": t0, + "sig": sig, + } + + kernel = get_kernel(kernel_type="RN", kernel_params=qpolikelihood_params) + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return QPOlog_likelihood + + +class GPResult: + """ + Makes a GP regressor for a given GP class and a prior over it. + Provides the sampled hyperparameters and tabulates their charachtersistics + Using jaxns for nested sampling and evidence analysis + + Parameters + ---------- + GP: class: GP + The initial GP class, on which we will apply our regressor. + + prior_type: string tuple + Has two strings with the first being the name of the kernel type + and the secound being the mean type for the prior + + prior_parameters: dict, default = None + Dictionary containing the parameters for the mean and kernel priors + The keys should be accourding to the selected kernel and mean + prior coressponding to the prior_type + By default, it takes a value None, and the kernel and mean priors are + then bulit using the pre-set parameters. + + Other Parameters + ---------------- + lc: Stingray.Lightcurve object + The lightcurve on which the gaussian process regression, is to be done + + """ + + def __init__(self, GP: GP, prior_type: tuple, prior_parameters=None) -> None: + self.gpclass = GP + self.prior_type = prior_type + self.prior_parameters = prior_parameters + self.lc = GP.lc + + def run_sampling(self): + """ + Runs a sampling process for the hyperparameters for the GP model. + Based on No U turn Sampling from the numpyro module + """ + + dict = {"Times": self.lc.time, "counts": self.lc.counts} + self.prior_model = get_prior(self.prior_type[0], self.prior_type[1], **dict) + self.likelihood = get_likelihood(self.prior_type[0], self.prior_type[1], **dict) + + NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood) + + NSmodel.sanity_check(random.PRNGKey(10), S=100) + + self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) + Termination_reason, State = self.Exact_ns( + random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) + ) + self.Results = self.Exact_ns.to_results(State, Termination_reason) + + def print_summary(self): + """ + Prints a summary table for the model parameters + """ + self.Exact_ns.summary(self.Results) + + def plot_diagnostics(self): + """ + Plots the diagnostic plots for the sampling process + """ + self.Exact_ns.plot_diagnostics(self.Results) + + def corner_plot(self): + """ + Plots the corner plot for the sampled hyperparameters + """ + self.Exact_ns.plot_corner(self.Results) + + def get_parameters(self): + """ + Returns the optimal parameters for the model based on the NUTS sampling + """ + + pass + + def plot_posterior(self, X_test): + """ + Plots posterior gaussian process, conditioned on the lightcurve + Also, plots the lightcurve along with it + + Parameters + ---------- + X_test: jnp.array + Array over which the Gaussian process values are to be obtained + Can be made default with lc.times as default + + """ + + pass From 647c72ef422c2f10bd35c510f70bcbbd49302878 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 14 Jun 2023 18:43:55 +0530 Subject: [PATCH 25/50] Added skew means --- stingray/modeling/gpmodeling.py | 212 +++++++++++++++++++++++++++++++- 1 file changed, 209 insertions(+), 3 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 1c2917a28..8f8f0793a 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -78,23 +78,128 @@ def get_mean(mean_type, mean_params): mean = functools.partial(_exponential, mean_params=mean_params) elif mean_type == "constant": mean = functools.partial(_constant, mean_params=mean_params) + elif mean_type == "skew_gaussian": + mean = functools.partial(_skew_gaussian, mean_params=mean_params) + elif mean_type == "skew_exponential": + mean = functools.partial(_skew_exponential, mean_params=mean_params) return mean def _gaussian(t, mean_params): + """A gaussian flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + sig1: + The width parameter for the gaussian. + + Returns + ------- + The y values for the gaussian flare. + """ return mean_params["A"] * jnp.exp( -((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig"] ** 2)) ) def _exponential(t, mean_params): + """An exponential flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + sig1: + The width parameter for the exponential. + + Returns + ------- + The y values for exponential flare. + """ return mean_params["A"] * jnp.exp(-jnp.abs((t - mean_params["t0"])) / mean_params["sig"]) def _constant(t, mean_params): + """A constant mean shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Constant amplitude of the flare. + + Returns + ------- + The constant value. + """ return mean_params["A"] * jnp.ones_like(t) +def _skew_gaussian(t, mean_params): + """A skew gaussian flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + sig1: + The width parameter for the rising edge. + sig2: + The width parameter for the falling edge. + + Returns + ------- + The y values for skew gaussian flare. + """ + return mean_params["A"] * jnp.where( + t > mean_params["t0"], + jnp.exp(-((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig2"] ** 2))), + jnp.exp(-((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig1"] ** 2))), + ) + + +def _skew_exponential(t, mean_params): + """A skew exponential flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + sig1: + The width parameter for the rising edge. + sig2: + The width parameter for the falling edge. + + Returns + ------- + The y values for exponential flare. + """ + return mean_params["A"] * jnp.where( + t > mean_params["t0"], + jnp.exp(-(t - mean_params["t0"]) / mean_params["sig2"]), + jnp.exp((t - mean_params["t0"]) / mean_params["sig1"]), + ) + + class GP: """ Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian @@ -272,6 +377,49 @@ def QPOprior_model(): if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): return QPOprior_model + def skew_RNprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + sig1 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig1") + sig2 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig2") + + return arn, crn, A, t0, sig1, sig2 + + if (kernel_type == "RN") & ((mean_type == "skew_gaussian") | (mean_type == "skew_exponential")): + return skew_RNprior_model + + def skew_QPOprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") + cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") + freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + sig1 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig1") + sig2 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig2") + + return arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2 + + if (kernel_type == "QPO_plus_RN") & ( + (mean_type == "skew_gaussian") | (mean_type == "skew_exponential") + ): + return skew_QPOprior_model + def get_likelihood(kernel_type, mean_type, **kwargs): """ @@ -320,7 +468,7 @@ def QPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig): "sig": sig, } - kernel = get_kernel(kernel_type="RN", kernel_params=qpolikelihood_params) + kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) mean = get_mean(mean_type=mean_type, mean_params=mean_params) gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) @@ -329,6 +477,63 @@ def QPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig): if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): return QPOlog_likelihood + @jit + def skew_RNlog_likelihood(arn, crn, A, t0, sig1, sig2): + rnlikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": 0.0, + "cqpo": 0.0, + "freq": 0.0, + } + + mean_params = { + "A": A, + "t0": t0, + "sig1": sig1, + "sig2": sig2, + } + + kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) + + # This could be causing problems + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): + return skew_RNlog_likelihood + + @jit + def skewQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2): + qpolikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": aqpo, + "cqpo": cqpo, + "freq": freq, + } + + mean_params = { + "A": A, + "t0": t0, + "sig1": sig1, + "sig2": sig2, + } + + kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) + + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "QPO_plus_RN") & ( + (mean_type == "skew_gaussian") | (mean_type == "skew_exponential") + ): + return skewQPOlog_likelihood + class GPResult: """ @@ -384,6 +589,7 @@ def run_sampling(self): random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) self.Results = self.Exact_ns.to_results(State, Termination_reason) + print("Simulation Complete") def print_summary(self): """ @@ -397,11 +603,11 @@ def plot_diagnostics(self): """ self.Exact_ns.plot_diagnostics(self.Results) - def corner_plot(self): + def plot_cornerplot(self): """ Plots the corner plot for the sampled hyperparameters """ - self.Exact_ns.plot_corner(self.Results) + self.Exact_ns.plot_cornerplot(self.Results) def get_parameters(self): """ From 0e091ae8c375dc7a5846d2bb7e3349732afa48cd Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Mon, 19 Jun 2023 21:17:34 +0530 Subject: [PATCH 26/50] Added fred model --- stingray/modeling/gpmodeling.py | 133 ++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 8f8f0793a..5b04ff120 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -82,6 +82,8 @@ def get_mean(mean_type, mean_params): mean = functools.partial(_skew_gaussian, mean_params=mean_params) elif mean_type == "skew_exponential": mean = functools.partial(_skew_exponential, mean_params=mean_params) + elif mean_type == "fred": + mean = functools.partial(_fred, mean_params=mean_params) return mean @@ -200,6 +202,39 @@ def _skew_exponential(t, mean_params): ) +def _fred(t, mean_params): + """A fast rise exponential decay (FRED) flare shape. + + Parameters + ---------- + t: jnp.ndarray + The time coordinates. + A: jnp.int + Amplitude of the flare. + t0: + The location of the maximum. + phi: + Symmetry parameter of the flare. + delta: + Offset parameter of the flare. + + Returns + ------- + The y values for exponential flare. + """ + return ( + mean_params["A"] + * jnp.exp( + -mean_params["phi"] + * ( + (t + mean_params["delta"]) / mean_params["t0"] + + mean_params["t0"] / (t + mean_params["delta"]) + ) + ) + * jnp.exp(2 * mean_params["phi"]) + ) + + class GP: """ Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian @@ -420,6 +455,47 @@ def skew_QPOprior_model(): ): return skew_QPOprior_model + def fred_RNprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + phi = yield Prior(tfpd.Uniform(2 * jnp.exp(-2), 2 * jnp.exp(4)), name="phi") + delta = yield Prior(tfpd.Uniform(0, kwargs["Times"][-1] / 2), name="delta") + + return arn, crn, A, t0, phi, delta + + if (kernel_type == "RN") & (mean_type == "fred"): + return fred_RNprior_model + + def fred_QPOprior_model(): + arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") + crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") + aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") + cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") + freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") + + A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") + t0 = yield Prior( + tfpd.Uniform( + kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] + ), + name="t0", + ) + phi = yield Prior(tfpd.Uniform(2 * jnp.exp(-2), 2 * jnp.exp(4)), name="phi") + delta = yield Prior(tfpd.Uniform(0, kwargs["Times"][-1] / 2), name="delta") + + return arn, crn, aqpo, cqpo, freq, A, t0, phi, delta + + if (kernel_type == "QPO_plus_RN") & (mean_type == "fred"): + return fred_QPOprior_model + def get_likelihood(kernel_type, mean_type, **kwargs): """ @@ -534,6 +610,63 @@ def skewQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2): ): return skewQPOlog_likelihood + @jit + def fred_RNlog_likelihood(arn, crn, A, t0, phi, delta): + rnlikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": 0.0, + "cqpo": 0.0, + "freq": 0.0, + } + + mean_params = { + "A": A, + "t0": t0, + "phi": phi, + "delta": delta, + } + + kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) + + # This could be causing problems + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + # gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "RN") & (mean_type == "fred"): + return fred_RNlog_likelihood + + @jit + def fredQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, phi, delta): + qpolikelihood_params = { + "arn": arn, + "crn": crn, + "aqpo": aqpo, + "cqpo": cqpo, + "freq": freq, + } + + mean_params = { + "A": A, + "t0": t0, + "phi": phi, + "delta": delta, + } + + kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) + + mean = get_mean(mean_type=mean_type, mean_params=mean_params) + + # gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) + gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) + return gp.log_probability(kwargs["counts"]) + + if (kernel_type == "QPO_plus_RN") & (mean_type == "fred"): + return fredQPOlog_likelihood + class GPResult: """ From dbaee735e8f5a10c7b9c1c167342d3bbb2191198 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 27 Jun 2023 16:05:57 +0530 Subject: [PATCH 27/50] Combined GP, GPR class --- stingray/modeling/gpmodeling.py | 492 ++++++++------------------------ 1 file changed, 120 insertions(+), 372 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 5b04ff120..ab810caf3 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -21,7 +21,7 @@ tfpd = tfp.distributions tfpb = tfp.bijectors -__all__ = ["GP", "GPResult"] +__all__ = ["GP"] def get_kernel(kernel_type, kernel_params): @@ -235,6 +235,99 @@ def _fred(t, mean_params): ) +def get_kernel_params(kernel_type): + if kernel_type == "RN": + return ["arn", "crn"] + elif kernel_type == "QPO_plus_RN": + return ["arn", "crn", "aqpo", "cqpo", "freq"] + + +def get_mean_params(mean_type): + if (mean_type == "gaussian") or (mean_type == "exponential"): + return ["A", "t0", "sig"] + elif mean_type == "constant": + return ["A"] + elif (mean_type == "skew_gaussian") or (mean_type == "skew_exponential"): + return ["A", "t0", "sig1", "sig2"] + elif mean_type == "fred": + return ["A", "t0", "delta", "phi"] + + +def get_gp_params(kernel_type, mean_type): + kernel_params = get_kernel_params(kernel_type) + mean_params = get_mean_params(mean_type) + kernel_params.extend(mean_params) + return kernel_params + + +def get_prior(params_list, prior_dict): + """ + A prior generator function based on given values + + Parameters + ---------- + params_list: + A list in order of the parameters to be used. + + prior_dict: + A dictionary of the priors of parameters to be used. + + Returns + ------- + The Prior function. + The arguments of the prior function are in the order of + Kernel arguments (RN arguments, QPO arguments), + Mean arguments + Non Windowed arguments + + """ + + def prior_model(): + prior_list = [] + for i in params_list: + if isinstance(prior_dict[i], tfpd.Distribution): + parameter = yield Prior(prior_dict[i], name=i) + else: + parameter = yield prior_dict[i] + prior_list.append(parameter) + return tuple(prior_list) + + return prior_model + + +def get_likelihood(params_list, kernel_type, mean_type, **kwargs): + """ + A likelihood generator function based on given values + + Parameters + ---------- + params_list: + A list in order of the parameters to be used. + + prior_dict: + A dictionary of the priors of parameters to be used. + + kernel_type: + The type of kernel to be used in the model. + + mean_type: + The type of mean to be used in the model. + + """ + + @jit + def likelihood_model(*args): + dict = {} + for i, params in enumerate(params_list): + dict[params] = args[i] + kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict) + mean = get_mean(mean_type=mean_type, mean_params=dict) + gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) + return gp.log_probability(kwargs["counts"]) + + return likelihood_model + + class GP: """ Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian @@ -269,14 +362,16 @@ class GP: """ - def __init__(self, Lc: Lightcurve, Model_type: tuple, Model_params: dict = None) -> None: + def __init__(self, Lc: Lightcurve) -> None: self.lc = Lc - self.Model_type = Model_type - self.Model_param = Model_params - self.kernel = get_kernel(self.Model_type[0], self.Model_param) - self.mean = get_mean(self.Model_type[1], self.Model_param) + self.time = Lc.time + self.counts = Lc.counts + + def fit(self, kernel=None, mean=None, **kwargs): + self.kernel = kernel + self.mean = mean self.maingp = GaussianProcess( - self.kernel, Lc.time, mean=self.mean, diag=Model_params["diag"] + self.kernel, self.time, mean_value=self.mean(self.time), diag=kwargs["diag"] ) def get_logprob(self): @@ -287,12 +382,6 @@ def get_logprob(self): cond = self.maingp.condition(self.lc.counts) return cond.log_probability - def get_model(self): - """ - Returns the model of the Gaussian Process - """ - return (self.Model_type, self.Model_param) - def plot_kernel(self): """ Plots the kernel of the Gaussian Process @@ -349,372 +438,31 @@ def plot_gp(self, sample_no=1, seed=0): ax.set_ylabel("counts") ax.legend(loc="best") + def sample(self, prior_model=None, likelihood_model=None, **kwargs): + """ + Makes a Jaxns nested sampler over the Gaussian Process, given the + prior and likelihood model -def get_prior(kernel_type, mean_type, **kwargs): - """ - A prior generator function based on given values - - Parameters - ---------- - kwargs: - All possible keyword arguments to construct the prior. - - Returns - ------- - The Prior function. - The arguments of the prior function are in the order of - Kernel arguments (RN arguments, QPO arguments), - Mean arguments - Non Windowed arguments - - """ - kwargs["T"] = kwargs["Times"][-1] - kwargs["Times"][0] # Total time - kwargs["f"] = 1 / (kwargs["Times"][1] - kwargs["Times"][0]) # Sampling frequency - kwargs["min"] = jnp.min(kwargs["counts"]) - kwargs["max"] = jnp.max(kwargs["counts"]) - kwargs["span"] = kwargs["max"] - kwargs["min"] - - def RNprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - sig = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig") - return arn, crn, A, t0, sig - - if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return RNprior_model - - def QPOprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") - cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") - freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - sig = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig") - - return arn, crn, aqpo, cqpo, freq, A, t0, sig - - if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return QPOprior_model - - def skew_RNprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - sig1 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig1") - sig2 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig2") - - return arn, crn, A, t0, sig1, sig2 - - if (kernel_type == "RN") & ((mean_type == "skew_gaussian") | (mean_type == "skew_exponential")): - return skew_RNprior_model - - def skew_QPOprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") - cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") - freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - sig1 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig1") - sig2 = yield Prior(tfpd.Uniform(0.5 * 1 / kwargs["f"], 2 * kwargs["T"]), name="sig2") - - return arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2 - - if (kernel_type == "QPO_plus_RN") & ( - (mean_type == "skew_gaussian") | (mean_type == "skew_exponential") - ): - return skew_QPOprior_model - - def fred_RNprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - phi = yield Prior(tfpd.Uniform(2 * jnp.exp(-2), 2 * jnp.exp(4)), name="phi") - delta = yield Prior(tfpd.Uniform(0, kwargs["Times"][-1] / 2), name="delta") - - return arn, crn, A, t0, phi, delta - - if (kernel_type == "RN") & (mean_type == "fred"): - return fred_RNprior_model - - def fred_QPOprior_model(): - arn = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="arn") - crn = yield Prior(tfpd.Uniform(jnp.log(1 / kwargs["T"]), jnp.log(kwargs["f"])), name="crn") - aqpo = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="aqpo") - cqpo = yield Prior(tfpd.Uniform(1 / 10 / kwargs["T"], jnp.log(kwargs["f"])), name="cqpo") - freq = yield Prior(tfpd.Uniform(2 / kwargs["T"], kwargs["f"] / 2), name="freq") - - A = yield Prior(tfpd.Uniform(0.1 * kwargs["span"], 2 * kwargs["span"]), name="A") - t0 = yield Prior( - tfpd.Uniform( - kwargs["Times"][0] - 0.1 * kwargs["T"], kwargs["Times"][-1] + 0.1 * kwargs["T"] - ), - name="t0", - ) - phi = yield Prior(tfpd.Uniform(2 * jnp.exp(-2), 2 * jnp.exp(4)), name="phi") - delta = yield Prior(tfpd.Uniform(0, kwargs["Times"][-1] / 2), name="delta") - - return arn, crn, aqpo, cqpo, freq, A, t0, phi, delta - - if (kernel_type == "QPO_plus_RN") & (mean_type == "fred"): - return fred_QPOprior_model - - -def get_likelihood(kernel_type, mean_type, **kwargs): - """ - A likelihood generator function based on given values - """ - - @jit - def RNlog_likelihood(arn, crn, A, t0, sig): - rnlikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": 0.0, - "cqpo": 0.0, - "freq": 0.0, - } - - mean_params = { - "A": A, - "t0": t0, - "sig": sig, - } - - kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) - - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return RNlog_likelihood - - @jit - def QPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig): - qpolikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": aqpo, - "cqpo": cqpo, - "freq": freq, - } - - mean_params = { - "A": A, - "t0": t0, - "sig": sig, - } - - kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "QPO_plus_RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return QPOlog_likelihood - - @jit - def skew_RNlog_likelihood(arn, crn, A, t0, sig1, sig2): - rnlikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": 0.0, - "cqpo": 0.0, - "freq": 0.0, - } - - mean_params = { - "A": A, - "t0": t0, - "sig1": sig1, - "sig2": sig2, - } - - kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) - - # This could be causing problems - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "RN") & ((mean_type == "gaussian") | (mean_type == "exponential")): - return skew_RNlog_likelihood - - @jit - def skewQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, sig1, sig2): - qpolikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": aqpo, - "cqpo": cqpo, - "freq": freq, - } - - mean_params = { - "A": A, - "t0": t0, - "sig1": sig1, - "sig2": sig2, - } - - kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) - - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "QPO_plus_RN") & ( - (mean_type == "skew_gaussian") | (mean_type == "skew_exponential") - ): - return skewQPOlog_likelihood - - @jit - def fred_RNlog_likelihood(arn, crn, A, t0, phi, delta): - rnlikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": 0.0, - "cqpo": 0.0, - "freq": 0.0, - } - - mean_params = { - "A": A, - "t0": t0, - "phi": phi, - "delta": delta, - } - - kernel = get_kernel(kernel_type="RN", kernel_params=rnlikelihood_params) - - # This could be causing problems - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - # gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "RN") & (mean_type == "fred"): - return fred_RNlog_likelihood - - @jit - def fredQPOlog_likelihood(arn, crn, aqpo, cqpo, freq, A, t0, phi, delta): - qpolikelihood_params = { - "arn": arn, - "crn": crn, - "aqpo": aqpo, - "cqpo": cqpo, - "freq": freq, - } - - mean_params = { - "A": A, - "t0": t0, - "phi": phi, - "delta": delta, - } - - kernel = get_kernel(kernel_type="QPO_plus_RN", kernel_params=qpolikelihood_params) - - mean = get_mean(mean_type=mean_type, mean_params=mean_params) - - # gp = GaussianProcess(kernel, kwargs["Times"], mean=mean) - gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) - return gp.log_probability(kwargs["counts"]) - - if (kernel_type == "QPO_plus_RN") & (mean_type == "fred"): - return fredQPOlog_likelihood - - -class GPResult: - """ - Makes a GP regressor for a given GP class and a prior over it. - Provides the sampled hyperparameters and tabulates their charachtersistics - Using jaxns for nested sampling and evidence analysis - - Parameters - ---------- - GP: class: GP - The initial GP class, on which we will apply our regressor. - - prior_type: string tuple - Has two strings with the first being the name of the kernel type - and the secound being the mean type for the prior - - prior_parameters: dict, default = None - Dictionary containing the parameters for the mean and kernel priors - The keys should be accourding to the selected kernel and mean - prior coressponding to the prior_type - By default, it takes a value None, and the kernel and mean priors are - then bulit using the pre-set parameters. - - Other Parameters - ---------------- - lc: Stingray.Lightcurve object - The lightcurve on which the gaussian process regression, is to be done + Parameters + ---------- + prior_model: jaxns.prior.PriorModelType object + A prior generator object - """ + likelihood_model: jaxns.types.LikelihoodType object + A likelihood fucntion which takes in the arguments of the prior + model and returns the loglikelihood of the model - def __init__(self, GP: GP, prior_type: tuple, prior_parameters=None) -> None: - self.gpclass = GP - self.prior_type = prior_type - self.prior_parameters = prior_parameters - self.lc = GP.lc + Returns + ---------- + Results: jaxns.results.NestedSamplerResults object + The results of the nested sampling process - def run_sampling(self): """ - Runs a sampling process for the hyperparameters for the GP model. - Based on No U turn Sampling from the numpyro module - """ - - dict = {"Times": self.lc.time, "counts": self.lc.counts} - self.prior_model = get_prior(self.prior_type[0], self.prior_type[1], **dict) - self.likelihood = get_likelihood(self.prior_type[0], self.prior_type[1], **dict) - NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood) + self.prior_model = prior_model + self.likelihood_model = likelihood_model + NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) From ada63a016d777b97e53b3a118df716d0e2163a16 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 27 Jun 2023 23:01:31 +0530 Subject: [PATCH 28/50] Added kernel, mean tests --- stingray/modeling/gpmodeling.py | 14 +++-- stingray/modeling/tests/test_gpmodeling.py | 60 ++++++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) create mode 100644 stingray/modeling/tests/test_gpmodeling.py diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index ab810caf3..a26b74a09 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -105,9 +105,11 @@ def _gaussian(t, mean_params): ------- The y values for the gaussian flare. """ - return mean_params["A"] * jnp.exp( - -((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig"] ** 2)) - ) + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + + return jnp.sum(A * jnp.exp(-((t - t0) ** 2) / (2 * (sig**2))), axis=0) def _exponential(t, mean_params): @@ -128,7 +130,11 @@ def _exponential(t, mean_params): ------- The y values for exponential flare. """ - return mean_params["A"] * jnp.exp(-jnp.abs((t - mean_params["t0"])) / mean_params["sig"]) + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + + return jnp.sum(A * jnp.exp(-jnp.abs(t - t0) / (2 * (sig**2))), axis=0) def _constant(t, mean_params): diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py new file mode 100644 index 000000000..19e2305dc --- /dev/null +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -0,0 +1,60 @@ +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow_probability.substrates.jax as tfp +import matplotlib.pyplot as plt + +from tinygp import GaussianProcess, kernels +from stingray.modeling.gpmodeling import get_kernel, get_mean, GP + + +class Testget_kernel(object): + def setup_class(self): + self.x = np.linspace(0, 1, 5) + self.kernel_params = {"arn": 1.0, "aqpo": 1.0, "crn": 1.0, "cqpo": 1.0, "freq": 1.0} + + def test_get_kernel_qpo_plus_rn(self): + kernel_qpo_plus_rn = kernels.quasisep.Exp( + scale=1 / 1, sigma=(1) ** 0.5 + ) + kernels.quasisep.Celerite( + a=1, + b=0.0, + c=1, + d=2 * jnp.pi * 1, + ) + kernel_qpo_plus_rn_test = get_kernel("QPO_plus_RN", self.kernel_params) + assert ( + kernel_qpo_plus_rn(self.x, jnp.array([0.0])) + == kernel_qpo_plus_rn_test(self.x, jnp.array([0.0])) + ).all() + + def test_get_kernel_rn(self): + kernel_rn = kernels.quasisep.Exp(scale=1 / 1, sigma=(1) ** 0.5) + kernel_rn_test = get_kernel("RN", self.kernel_params) + assert ( + kernel_rn(self.x, jnp.array([0.0])) == kernel_rn_test(self.x, jnp.array([0.0])) + ).all() + + +class Testget_mean(object): + def setup_class(self): + self.t = np.linspace(0, 5, 10) + self.mean_params_gaussian = { + "A": jnp.array([3.0, 4.0]), + "t0": jnp.array([0.2, 0.7]), + "sig": jnp.array([0.2, 0.1]), + } + + def test_get_mean_gaussian(self): + result_gaussian = 3 * jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.2**2))) + 4 * jnp.exp( + -((self.t - 0.7) ** 2) / (2 * (0.1**2)) + ) + assert (get_mean("gaussian", self.mean_params_gaussian)(self.t) == result_gaussian).all() + + def test_get_mean_exponential(self): + result_exponential = 3 * jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.2**2))) + 4 * jnp.exp( + -jnp.abs(self.t - 0.7) / (2 * (0.1**2)) + ) + assert ( + get_mean("exponential", self.mean_params_gaussian)(self.t) == result_exponential + ).all() From c1a08a39ba292e5ca0ed4232746ffb62d41dc8dd Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 28 Jun 2023 01:49:25 +0530 Subject: [PATCH 29/50] Added Multimean and tests --- stingray/modeling/gpmodeling.py | 51 +++++++++++++------- stingray/modeling/tests/test_gpmodeling.py | 55 ++++++++++++++++++++-- 2 files changed, 85 insertions(+), 21 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index a26b74a09..2c9426060 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -174,10 +174,19 @@ def _skew_gaussian(t, mean_params): ------- The y values for skew gaussian flare. """ - return mean_params["A"] * jnp.where( - t > mean_params["t0"], - jnp.exp(-((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig2"] ** 2))), - jnp.exp(-((t - mean_params["t0"]) ** 2) / (2 * (mean_params["sig1"] ** 2))), + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + + return jnp.sum( + A + * jnp.where( + t > t0, + jnp.exp(-((t - t0) ** 2) / (2 * (sig2**2))), + jnp.exp(-((t - t0) ** 2) / (2 * (sig1**2))), + ), + axis=0, ) @@ -201,10 +210,19 @@ def _skew_exponential(t, mean_params): ------- The y values for exponential flare. """ - return mean_params["A"] * jnp.where( - t > mean_params["t0"], - jnp.exp(-(t - mean_params["t0"]) / mean_params["sig2"]), - jnp.exp((t - mean_params["t0"]) / mean_params["sig1"]), + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + + return jnp.sum( + A + * jnp.where( + t > t0, + jnp.exp(-(t - t0) / (2 * (sig2**2))), + jnp.exp((t - t0) / (2 * (sig1**2))), + ), + axis=0, ) @@ -228,16 +246,13 @@ def _fred(t, mean_params): ------- The y values for exponential flare. """ - return ( - mean_params["A"] - * jnp.exp( - -mean_params["phi"] - * ( - (t + mean_params["delta"]) / mean_params["t0"] - + mean_params["t0"] / (t + mean_params["delta"]) - ) - ) - * jnp.exp(2 * mean_params["phi"]) + A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] + phi = jnp.atleast_1d(mean_params["phi"])[:, jnp.newaxis] + delta = jnp.atleast_1d(mean_params["delta"])[:, jnp.newaxis] + + return jnp.sum( + A * jnp.exp(-phi * ((t + delta) / t0 + t0 / (t + delta))) * jnp.exp(2 * phi), axis=0 ) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 19e2305dc..27ddadcec 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -39,22 +39,71 @@ def test_get_kernel_rn(self): class Testget_mean(object): def setup_class(self): self.t = np.linspace(0, 5, 10) - self.mean_params_gaussian = { + self.mean_params = { "A": jnp.array([3.0, 4.0]), "t0": jnp.array([0.2, 0.7]), "sig": jnp.array([0.2, 0.1]), } + self.skew_mean_params = { + "A": jnp.array([3.0, 4.0]), + "t0": jnp.array([0.2, 0.7]), + "sig1": jnp.array([0.2, 0.1]), + "sig2": jnp.array([0.3, 0.4]), + } + self.fred_mean_params = { + "A": jnp.array([3.0, 4.0]), + "t0": jnp.array([0.2, 0.7]), + "phi": jnp.array([4.0, 5.0]), + "delta": jnp.array([0.3, 0.4]), + } def test_get_mean_gaussian(self): result_gaussian = 3 * jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.2**2))) + 4 * jnp.exp( -((self.t - 0.7) ** 2) / (2 * (0.1**2)) ) - assert (get_mean("gaussian", self.mean_params_gaussian)(self.t) == result_gaussian).all() + assert (get_mean("gaussian", self.mean_params)(self.t) == result_gaussian).all() def test_get_mean_exponential(self): result_exponential = 3 * jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.2**2))) + 4 * jnp.exp( -jnp.abs(self.t - 0.7) / (2 * (0.1**2)) ) + assert (get_mean("exponential", self.mean_params)(self.t) == result_exponential).all() + + def test_get_mean_constant(self): + result_constant = 3 * jnp.ones_like(self.t) + const_param_dict = {"A": jnp.array([3.0])} + assert (get_mean("constant", const_param_dict)(self.t) == result_constant).all() + + def test_get_mean_skew_gaussian(self): + result_skew_gaussian = 3.0 * jnp.where( + self.t > 0.2, + jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.3**2))), + jnp.exp(-((self.t - 0.2) ** 2) / (2 * (0.2**2))), + ) + 4.0 * jnp.where( + self.t > 0.7, + jnp.exp(-((self.t - 0.7) ** 2) / (2 * (0.4**2))), + jnp.exp(-((self.t - 0.7) ** 2) / (2 * (0.1**2))), + ) assert ( - get_mean("exponential", self.mean_params_gaussian)(self.t) == result_exponential + get_mean("skew_gaussian", self.skew_mean_params)(self.t) == result_skew_gaussian ).all() + + def test_get_mean_skew_exponential(self): + result_skew_exponential = 3.0 * jnp.where( + self.t > 0.2, + jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.3**2))), + jnp.exp(-jnp.abs(self.t - 0.2) / (2 * (0.2**2))), + ) + 4.0 * jnp.where( + self.t > 0.7, + jnp.exp(-jnp.abs(self.t - 0.7) / (2 * (0.4**2))), + jnp.exp(-jnp.abs(self.t - 0.7) / (2 * (0.1**2))), + ) + assert ( + get_mean("skew_exponential", self.skew_mean_params)(self.t) == result_skew_exponential + ).all() + + def test_get_mean_fred(self): + result_fred = 3.0 * jnp.exp(-4.0 * ((self.t + 0.3) / 0.2 + 0.2 / (self.t + 0.3))) * jnp.exp( + 2 * 4.0 + ) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0) + assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all() From 827cdf77d221f90a90fea3e0717ac61fee70c6ec Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Thu, 29 Jun 2023 14:55:38 +0530 Subject: [PATCH 30/50] Testing for get_pior_params --- stingray/modeling/tests/test_gpmodeling.py | 84 +++++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 27ddadcec..abc9632cf 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt from tinygp import GaussianProcess, kernels -from stingray.modeling.gpmodeling import get_kernel, get_mean, GP +from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params class Testget_kernel(object): @@ -107,3 +107,85 @@ def test_get_mean_fred(self): 2 * 4.0 ) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0) assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all() + + class Testget_gp_params(object): + def setup_class(self): + pass + + def test_get_gp_params_rn(self): + assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] + assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] + assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] + assert get_gp_params("RN", "skew_exponential") == [ + "arn", + "crn", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] + assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] + + def test_get_gp_params_qpo_plus_rn(self): + assert get_gp_params("QPO_plus_RN", "gaussian") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig", + ] + assert get_gp_params("QPO_plus_RN", "constant") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + ] + assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("QPO_plus_RN", "exponential") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig", + ] + assert get_gp_params("QPO_plus_RN", "fred") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "delta", + "phi", + ] From 8fa0059d74807936b96f27bc13bd08e5867c6628 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 11 Jul 2023 18:23:41 +0530 Subject: [PATCH 31/50] Changed the GP class --- stingray/modeling/gpmodeling.py | 216 +++++++++++++++++--------------- 1 file changed, 112 insertions(+), 104 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 2c9426060..b82803f44 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -15,13 +15,14 @@ # from jaxns import analytic_log_evidence from jaxns import Prior, Model +from jaxns.utils import resample jax.config.update("jax_enable_x64", True) tfpd = tfp.distributions tfpb = tfp.bijectors -__all__ = ["GP"] +__all__ = ["GPResult"] def get_kernel(kernel_type, kernel_params): @@ -349,37 +350,23 @@ def likelihood_model(*args): return likelihood_model -class GP: +class GPResult: """ - Makes a GP object which takes in a Stingray.Lightcurve and fits a Gaussian - Process on the lightcurve data, for the given kernel. + Makes a GPResult object which takes in a Stingray.Lightcurve and samples parameters of a model + (Gaussian Process) based on the given prior and log_likelihood function. Parameters ---------- lc: Stingray.Lightcurve object - The lightcurve on which the gaussian process, is to be fitted - - Model_type: string tuple - Has two strings with the first being the name of the kernel type - and the secound being the mean type - - Model_parameter: dict, default = None - Dictionary conatining the parameters for the mean and kernel - The keys should be accourding to the selected kernel and mean - coressponding to the Model_type - By default, it takes a value None, and the kernel and mean are - then bulit using the pre-set parameters. + The lightcurve on which the bayesian inference is to be done Other Parameters ---------------- - kernel: class: `TinyGp.kernel` object - The tinygp kernel for the GP + time : class: np.array + The array containing the times of the lightcurve - mean: class: `TinyGp.mean` object - The tinygp mean for the GP - - maingp: class: `TinyGp.GaussianProcess` object - The tinygp gaussian process made on the lightcurve + counts : class: np.array + The array containing the photon counts of the lightcurve """ @@ -387,77 +374,7 @@ def __init__(self, Lc: Lightcurve) -> None: self.lc = Lc self.time = Lc.time self.counts = Lc.counts - - def fit(self, kernel=None, mean=None, **kwargs): - self.kernel = kernel - self.mean = mean - self.maingp = GaussianProcess( - self.kernel, self.time, mean_value=self.mean(self.time), diag=kwargs["diag"] - ) - - def get_logprob(self): - """ - Returns the logprobability of the lightcurves counts for the - given kernel for the Gaussian Process - """ - cond = self.maingp.condition(self.lc.counts) - return cond.log_probability - - def plot_kernel(self): - """ - Plots the kernel of the Gaussian Process - """ - X = self.lc.time - Y = self.kernel(X, np.array([0.0])) - plt.plot(X, Y) - plt.xlabel("distance") - plt.ylabel("Value") - plt.title("Kernel Function") - - def plot_originalgp(self, sample_no=1, seed=0): - """ - Plots samples obtained from the gaussian process for the kernel - - Parameters - ---------- - sample_no: int , default = 1 - Number of GP samples to be taken - - """ - X_test = self.lc.time - _, ax = plt.subplots(1, 1, figsize=(10, 3)) - y_samp = self.maingp.sample(jax.random.PRNGKey(seed), shape=(sample_no,)) - ax.plot(X_test, y_samp[0], "C0", lw=0.5, alpha=0.5, label="samples") - ax.plot(X_test, y_samp[1:].T, "C0", lw=0.5, alpha=0.5) - ax.set_xlabel("time") - ax.set_ylabel("counts") - ax.legend(loc="best") - - def plot_gp(self, sample_no=1, seed=0): - """ - Plots gaussian process, conditioned on the lightcurve - Also, plots the lightcurve along with it - - Parameters - ---------- - sample_no: int , default = 1 - Number of GP samples to be taken - - """ - X_test = self.lc.time - - _, ax = plt.subplots(1, 1, figsize=(10, 3)) - _, cond_gp = self.maingp.condition(self.lc.counts, X_test) - mu = cond_gp.mean - # std = np.sqrt(cond_gp.variance) - - ax.plot(self.lc.time, self.lc.counts, lw=2, color="blue", label="Lightcurve") - ax.plot(X_test, mu, "C1", label="Gaussian Process") - y_samp = cond_gp.sample(jax.random.PRNGKey(seed), shape=(sample_no,)) - ax.plot(X_test, y_samp[0], "C0", lw=0.5, alpha=0.5) - ax.set_xlabel("time") - ax.set_ylabel("counts") - ax.legend(loc="best") + self.Result = None def sample(self, prior_model=None, likelihood_model=None, **kwargs): """ @@ -493,6 +410,12 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): self.Results = self.Exact_ns.to_results(State, Termination_reason) print("Simulation Complete") + def get_evidence(self): + """ + Returns the log evidence of the model + """ + return self.Results.log_Z_mean + def print_summary(self): """ Prints a summary table for the model parameters @@ -511,24 +434,109 @@ def plot_cornerplot(self): """ self.Exact_ns.plot_cornerplot(self.Results) - def get_parameters(self): + def get_parameters_names(self): + """ + Returns the names of the parameters + """ + return sorted(self.Results.samples.keys()) + + def get_max_posterior_parameters(self): """ Returns the optimal parameters for the model based on the NUTS sampling """ + max_post_idx = jnp.argmax(self.Results.log_posterior_density) + map_points = jax.tree_map(lambda x: x[max_post_idx], self.Results.samples) - pass + return map_points - def plot_posterior(self, X_test): + def get_max_likelihood_parameters(self): + """ + Retruns the maximum likelihood parameters """ - Plots posterior gaussian process, conditioned on the lightcurve - Also, plots the lightcurve along with it + max_like_idx = jnp.argmax(self.Results.log_L_samples) + max_like_points = jax.tree_map(lambda x: x[max_like_idx], self.Results.samples) - Parameters - ---------- - X_test: jnp.array - Array over which the Gaussian process values are to be obtained - Can be made default with lc.times as default + return max_like_points + + def posterior_plot(self, name: str, n=0): + """ + Plots the posterior histogram for the given parameter + """ + nsamples = self.Results.total_num_samples + samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] + plt.hist( + samples, bins="auto", density=True, alpha=1.0, label=name, fc="None", edgecolor="black" + ) + mean1 = jnp.mean(self.Results.samples[name]) + std1 = jnp.std(self.Results.samples[name]) + plt.axvline(mean1, color="red", linestyle="dashed", label="mean") + plt.axvline(mean1 + std1, color="green", linestyle="dotted") + plt.axvline(mean1 - std1, linestyle="dotted", color="green") + plt.legend() + plt.plot() + + pass + + def weighted_posterior_plot(self, name: str, n=0, rkey=random.PRNGKey(1234)): + """ + Returns the weighted posterior histogram for the given parameter + """ + nsamples = self.Results.total_num_samples + log_p = self.Results.log_dp_mean + samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] + + weights = jnp.where(jnp.isfinite(samples), jnp.exp(log_p), 0.0) + log_weights = jnp.where(jnp.isfinite(samples), log_p, -jnp.inf) + samples_resampled = resample( + rkey, samples, log_weights, S=max(10, int(self.Results.ESS)), replace=True + ) + nbins = max(10, int(jnp.sqrt(self.Results.ESS)) + 1) + binsx = jnp.linspace(*jnp.percentile(samples_resampled, jnp.asarray([0, 100])), 2 * nbins) + + plt.hist( + np.asarray(samples_resampled), + bins=binsx, + density=True, + alpha=1.0, + label=name, + fc="None", + edgecolor="black", + ) + sample_mean = jnp.average(samples, weights=weights) + sample_std = jnp.sqrt(jnp.average((samples - sample_mean) ** 2, weights=weights)) + plt.axvline(sample_mean, color="red", linestyle="dashed", label="mean") + plt.axvline(sample_mean + sample_std, color="green", linestyle="dotted") + plt.axvline(sample_mean - sample_std, linestyle="dotted", color="green") + plt.legend() + plt.plot() + + def corner_plot(self, param1: str, param2: str, n1=0, n2=0, rkey=random.PRNGKey(1234)): """ + Plots the corner plot for the given parameters + """ + nsamples = self.Results.total_num_samples + log_p = self.Results.log_dp_mean + samples1 = self.Results.samples[param1].reshape((nsamples, -1))[:, n1] + samples2 = self.Results.samples[param2].reshape((nsamples, -1))[:, n2] + + log_weights = jnp.where(jnp.isfinite(samples2), log_p, -jnp.inf) + nbins = max(10, int(jnp.sqrt(self.Results.ESS)) + 1) + + samples_resampled = resample( + rkey, + jnp.stack([samples1, samples2], axis=-1), + log_weights, + S=max(10, int(self.Results.ESS)), + replace=True, + ) + plt.hist2d( + samples_resampled[:, 1], + samples_resampled[:, 0], + bins=(nbins, nbins), + density=True, + cmap="GnBu", + ) + plt.plot() pass From e00d43860db8b94a4c3e291bf2281f5da6db742b Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 18 Jul 2023 19:46:07 +0530 Subject: [PATCH 32/50] Improved library imports --- stingray/modeling/gpmodeling.py | 53 +++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index b82803f44..f9178ce37 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -1,26 +1,41 @@ import numpy as np import matplotlib.pyplot as plt -import jax -import jax.numpy as jnp import functools -import tensorflow_probability.substrates.jax as tfp +from stingray import Lightcurve + +try: + import jax +except ImportError: + raise ImportError("Jax not installed") +import jax.numpy as jnp from jax import jit, random -from tinygp import GaussianProcess, kernels -from stingray import Lightcurve +jax.config.update("jax_enable_x64", True) -from jaxns import ExactNestedSampler -from jaxns import TerminationCondition +try: + from tinygp import GaussianProcess, kernels -# from jaxns import analytic_log_evidence -from jaxns import Prior, Model -from jaxns.utils import resample + can_make_gp = True +except ImportError: + can_make_gp = False -jax.config.update("jax_enable_x64", True) +try: + from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model + from jaxns.utils import resample + + can_sample = True +except ImportError: + can_sample = False +try: + import tensorflow_probability.substrates.jax as tfp + + tfpd = tfp.distributions + tfpb = tfp.bijectors + tfp_available = True +except ImportError: + tfp_available = False -tfpd = tfp.distributions -tfpb = tfp.bijectors __all__ = ["GPResult"] @@ -41,6 +56,9 @@ def get_kernel(kernel_type, kernel_params): Should contain the parameters for the selected kernel """ + if not can_make_gp: + raise ImportError("Tinygp is required to make kernels") + if kernel_type == "QPO_plus_RN": kernel = kernels.quasisep.Exp( scale=1 / kernel_params["crn"], sigma=(kernel_params["arn"]) ** 0.5 @@ -303,6 +321,11 @@ def get_prior(params_list, prior_dict): Non Windowed arguments """ + if not can_sample: + raise ImportError("Jaxns not installed. Cannot make jaxns specific prior.") + + if not tfp_available: + raise ImportError("Tensorflow probability required to make priors.") def prior_model(): prior_list = [] @@ -336,6 +359,8 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): The type of mean to be used in the model. """ + if not can_make_gp: + raise ImportError("Tinygp is required to make the GP model.") @jit def likelihood_model(*args): @@ -396,6 +421,8 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): The results of the nested sampling process """ + if not can_sample: + raise ImportError("Jaxns not installed! Can't sample!") self.prior_model = prior_model self.likelihood_model = likelihood_model From b031ed1f99ede09011301c028c814f8babcdff10 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 18 Jul 2023 21:59:42 +0530 Subject: [PATCH 33/50] Primary tests for GPresult class --- stingray/modeling/tests/test_gpmodeling.py | 77 +++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index abc9632cf..f9949c3a6 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -1,11 +1,22 @@ import jax import jax.numpy as jnp +from jax import random + +jax.config.update("jax_enable_x64", True) + import numpy as np -import tensorflow_probability.substrates.jax as tfp import matplotlib.pyplot as plt from tinygp import GaussianProcess, kernels from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params +from stingray.modeling.gpmodeling import get_prior, get_likelihood, GPResult +from stingray import Lightcurve + +import tensorflow_probability.substrates.jax as tfp + +tfpd = tfp.distributions + +from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model class Testget_kernel(object): @@ -189,3 +200,67 @@ def test_get_gp_params_qpo_plus_rn(self): "delta", "phi", ] + + +class TestGPResult(object): + def setup_class(self): + self.Times = np.linspace(0, 1, 64) + kernel_params = { + "arn": jnp.exp(1.5), + "crn": jnp.exp(1.0), + } + mean_params = {"A": jnp.array([3.0]), "t0": jnp.array([0.2]), "sig": jnp.array([0.2])} + kernel = get_kernel("RN", kernel_params) + mean = get_mean("gaussian", mean_params) + + gp = GaussianProcess(kernel=kernel, X=self.Times, mean_value=mean(self.Times)) + self.counts = gp.sample(key=jax.random.PRNGKey(6)) + + lc = Lightcurve(time=self.Times, counts=self.counts, dt=self.Times[1] - self.Times[0]) + + self.params_list = get_gp_params(kernel_type="RN", mean_type="gaussian") + + T = self.Times[-1] - self.Times[0] + f = 1 / (self.Times[1] - self.Times[0]) + span = jnp.max(self.counts) - jnp.min(self.counts) + + # The prior dictionary, with suitable tfpd prior distributions + prior_dict = { + "A": tfpd.Uniform(low=0.1 * span, high=2 * span), + "t0": tfpd.Uniform(low=self.Times[0] - 0.1 * T, high=self.Times[-1] + 0.1 * T), + "sig": tfpd.Uniform(low=0.5 * 1 / f, high=2 * T), + "arn": tfpd.Uniform(low=0.1 * span, high=2 * span), + "crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)), + } + + prior_model = get_prior(self.params_list, prior_dict) + likelihood_model = get_likelihood( + self.params_list, + kernel_type="RN", + mean_type="gaussian", + Times=self.Times, + counts=self.counts, + ) + + NSmodel = Model(prior_model=prior_model, log_likelihood=likelihood_model) + NSmodel.sanity_check(random.PRNGKey(10), S=100) + + Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) + Termination_reason, State = Exact_ns( + random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) + ) + self.Results = Exact_ns.to_results(State, Termination_reason) + + self.gpresult = GPResult(lc) + self.gpresult.sample(prior_model=prior_model, likelihood_model=likelihood_model) + + def test_sample(self): + for key in self.params_list: + assert (self.Results.samples[key]).all() == (self.gpresult.Results.samples[key]).all() + + def test_get_evidence(self): + assert self.Results.log_Z_mean == self.gpresult.Results.log_Z_mean + + def plot_diagnostics(self): + self.gpresult.plot_diagnostics() + assert plt.fignum_exists(1) From 5d95841494fc4765ef3621ef128bc2b8f085eeec Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 19 Jul 2023 19:20:25 +0530 Subject: [PATCH 34/50] Added docstrings --- stingray/modeling/gpmodeling.py | 112 ++++++++++++++++++--- stingray/modeling/tests/test_gpmodeling.py | 11 +- 2 files changed, 106 insertions(+), 17 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index f9178ce37..38fadbdef 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -43,7 +43,7 @@ def get_kernel(kernel_type, kernel_params): """ Function for producing the kernel for the Gaussian Process. - Returns the selected Tinygp kernel + Returns the selected Tinygp kernel for the given parameters. Parameters ---------- @@ -275,14 +275,38 @@ def _fred(t, mean_params): ) -def get_kernel_params(kernel_type): +def _get_kernel_params(kernel_type): + """ + Generates a list of the parameters for the kernel for the GP model based on the kernel type. + + Parameters + ---------- + kernel_type: string + The type of kernel to be used for the Gaussian Process model + + Returns + ------- + A list of the parameters for the kernel for the GP model + """ if kernel_type == "RN": return ["arn", "crn"] elif kernel_type == "QPO_plus_RN": return ["arn", "crn", "aqpo", "cqpo", "freq"] -def get_mean_params(mean_type): +def _get_mean_params(mean_type): + """ + Generates a list of the parameters for the mean for the GP model based on the mean type. + + Parameters + ---------- + mean_type: string + The type of mean to be used for the Gaussian Process model + + Returns + ------- + A list of the parameters for the mean for the GP model + """ if (mean_type == "gaussian") or (mean_type == "exponential"): return ["A", "t0", "sig"] elif mean_type == "constant": @@ -294,15 +318,39 @@ def get_mean_params(mean_type): def get_gp_params(kernel_type, mean_type): - kernel_params = get_kernel_params(kernel_type) - mean_params = get_mean_params(mean_type) + """ + Generates a list of the parameters for the GP model based on the kernel and mean type. + To be used to set the order of the parameters for `get_prior` and `get_likelihood` functions. + + Parameters + ---------- + kernel_type: string + The type of kernel to be used for the Gaussian Process model + + mean_type: string + The type of mean to be used for the Gaussian Process model + + Returns + ------- + A list of the parameters for the GP model + + Examples + -------- + >>> get_gp_params("QPO_plus_RN", "gaussian") + ['arn', 'crn', 'aqpo', 'cqpo', 'freq', 'A', 't0', 'sig'] + """ + kernel_params = _get_kernel_params(kernel_type) + mean_params = _get_mean_params(mean_type) kernel_params.extend(mean_params) return kernel_params def get_prior(params_list, prior_dict): """ - A prior generator function based on given values + A prior generator function based on given values. + Makes a jaxns specific prior function based on the given prior dictionary. + Jaxns requires the parameters of the prior function and log_likelihood function to + be in the same order. This order is made according to the params_list. Parameters ---------- @@ -311,14 +359,35 @@ def get_prior(params_list, prior_dict): prior_dict: A dictionary of the priors of parameters to be used. + These parameters should be from tensorflow_probability distributions / Priors from jaxns + or special priors from jaxns. + **Note**: If jaxns priors are used, then the name given to them should be the same as + the corresponding name in the params_list. Returns ------- - The Prior function. + The Prior generator function. The arguments of the prior function are in the order of - Kernel arguments (RN arguments, QPO arguments), - Mean arguments - Non Windowed arguments + Kernel arguments (RN arguments, QPO arguments), + Mean arguments + Miscellaneous arguments + + Examples + -------- + A prior function for a Red Noise kernel and a Gaussian mean function + Obain the parameters list + >>> params_list = get_gp_params("RN", "gaussian") + + Make a prior dictionary using tensorflow_probability distributions + >>> prior_dict = { + ... "A": tfpd.Uniform(low = 1e-1, high = 2e+2), + ... "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), + ... "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), + ... "arn": tfpd.Uniform(low = 0.1 , high = 2 ), + ... "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), + ... } + + >>> prior_model = get_prior(params_list, prior_dict) """ if not can_sample: @@ -342,7 +411,11 @@ def prior_model(): def get_likelihood(params_list, kernel_type, mean_type, **kwargs): """ - A likelihood generator function based on given values + A log likelihood generator function based on given values. + Makes a jaxns specific log likelihood function which takes in the + parameters in the order of the parameters list, and calculates the + log likelihood of the data given the parameters, and the model + (kernel, mean) of the GP model. Parameters ---------- @@ -358,6 +431,19 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): mean_type: The type of mean to be used in the model. + **kwargs: + The keyword arguments to be used in the log likelihood function. + **Note**: The keyword arguments Times and counts are necessary for + calculating the log likelihood. + Times: np.array or jnp.array + The time array of the lightcurve + counts: np.array or jnp.array + The photon counts array of the lightcurve + + Returns + ------- + The jaxns specific log likelihood function. + """ if not can_make_gp: raise ImportError("Tinygp is required to make the GP model.") @@ -502,8 +588,6 @@ def posterior_plot(self, name: str, n=0): plt.legend() plt.plot() - pass - def weighted_posterior_plot(self, name: str, n=0, rkey=random.PRNGKey(1234)): """ Returns the weighted posterior histogram for the given parameter @@ -565,5 +649,3 @@ def corner_plot(self, param1: str, param2: str, n1=0, n2=0, rkey=random.PRNGKey( cmap="GnBu", ) plt.plot() - - pass diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index f9949c3a6..01ee0a398 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -259,8 +259,15 @@ def test_sample(self): assert (self.Results.samples[key]).all() == (self.gpresult.Results.samples[key]).all() def test_get_evidence(self): - assert self.Results.log_Z_mean == self.gpresult.Results.log_Z_mean + assert self.Results.log_Z_mean == self.gpresult.get_evidence() - def plot_diagnostics(self): + def test_plot_diagnostics(self): self.gpresult.plot_diagnostics() assert plt.fignum_exists(1) + + def test_plot_cornerplot(self): + self.gpresult.plot_cornerplot() + assert plt.fignum_exists(1) + + def test_get_parameters_names(self): + assert sorted(self.params_list) == self.gpresult.get_parameters_names() From c013d4ac149bbc2ae8d4f4d68406baa78156be58 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Thu, 20 Jul 2023 02:50:37 +0530 Subject: [PATCH 35/50] Plot function changed, tests added --- stingray/modeling/gpmodeling.py | 159 ++++++++++++++++++++- stingray/modeling/tests/test_gpmodeling.py | 93 +++++++++++- 2 files changed, 242 insertions(+), 10 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 38fadbdef..d2e7f2917 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -571,9 +571,37 @@ def get_max_likelihood_parameters(self): return max_like_points - def posterior_plot(self, name: str, n=0): + def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): """ Plots the posterior histogram for the given parameter + + Parameters + ---------- + name : str + Name of the parameter. + Should be from the names of the parameter list used or from the names of parameters + used in the prior_function + + n : int, default 0 + The index of the parameter to be plotted. + For multivariate parameters, the index of the specific parameter to be plotted. + + axis : list, tuple, string, default ``None`` + Parameter to set axis properties of ``matplotlib`` figure. For example + it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other + acceptable argument for ``matplotlib.pyplot.axis()`` method. + + save : bool, optionalm, default ``False`` + If ``True``, save the figure with specified filename. + + filename : str + File name and path of the image to save. Depends on the boolean ``save``. + + Returns + ------- + plt : ``matplotlib.pyplot`` object + Reference to plot, call ``show()`` to display it + """ nsamples = self.Results.total_num_samples samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] @@ -585,12 +613,56 @@ def posterior_plot(self, name: str, n=0): plt.axvline(mean1, color="red", linestyle="dashed", label="mean") plt.axvline(mean1 + std1, color="green", linestyle="dotted") plt.axvline(mean1 - std1, linestyle="dotted", color="green") + plt.title("Posterior Histogram of " + str(name)) + plt.xlabel(name) + plt.ylabel("Probability Density") plt.legend() - plt.plot() - def weighted_posterior_plot(self, name: str, n=0, rkey=random.PRNGKey(1234)): + if axis is not None: + plt.axis(axis) + + if save: + if filename is None: + plt.savefig(str(name) + "_Posterior_plot.png") + else: + plt.savefig(filename) + return plt + + def weighted_posterior_plot( + self, name: str, n=0, rkey=random.PRNGKey(1234), axis=None, save=False, filename=None + ): """ Returns the weighted posterior histogram for the given parameter + + Parameters + ---------- + name : str + Name of the parameter. + Should be from the names of the parameter list used or from the names of parameters + used in the prior_function + + n : int, default 0 + The index of the parameter to be plotted. + For multivariate parameters, the index of the specific parameter to be plotted. + + key: jax.random.PRNGKey, default ``random.PRNGKey(1234)`` + Random key for the weighted sampling + + axis : list, tuple, string, default ``None`` + Parameter to set axis properties of ``matplotlib`` figure. For example + it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other + acceptable argument for ``matplotlib.pyplot.axis()`` method. + + save : bool, optionalm, default ``False`` + If ``True``, save the figure with specified filename. + + filename : str + File name and path of the image to save. Depends on the boolean ``save``. + + Returns + ------- + plt : ``matplotlib.pyplot`` object + Reference to plot, call ``show()`` to display it """ nsamples = self.Results.total_num_samples log_p = self.Results.log_dp_mean @@ -619,12 +691,72 @@ def weighted_posterior_plot(self, name: str, n=0, rkey=random.PRNGKey(1234)): plt.axvline(sample_mean, color="red", linestyle="dashed", label="mean") plt.axvline(sample_mean + sample_std, color="green", linestyle="dotted") plt.axvline(sample_mean - sample_std, linestyle="dotted", color="green") + plt.title("Weighted Posterior Histogram of " + str(name)) + plt.xlabel(name) + plt.ylabel("Probability Density") plt.legend() - plt.plot() + if axis is not None: + plt.axis(axis) - def corner_plot(self, param1: str, param2: str, n1=0, n2=0, rkey=random.PRNGKey(1234)): + if save: + if filename is None: + plt.savefig(str(name) + "_Weighted_Posterior_plot.png") + else: + plt.savefig(filename) + return plt + + def corner_plot( + self, + param1: str, + param2: str, + n1=0, + n2=0, + rkey=random.PRNGKey(1234), + axis=None, + save=False, + filename=None, + ): """ - Plots the corner plot for the given parameters + Plots the corner plot between two given parameters + + Parameters + ---------- + param1 : str + Name of the first parameter. + Should be from the names of the parameter list used or from the names of parameters + used in the prior_function + + param2 : str + Name of the second parameter. + Should be from the names of the parameter list used or from the names of parameters + used in the prior_function + + n1 : int, default 0 + The index of the first parameter to be plotted. + For multivariate parameters, the index of the specific parameter to be plotted. + + n2 : int, default 0 + The index of the second parameter to be plotted. + For multivariate parameters, the index of the specific parameter to be plotted. + + key: jax.random.PRNGKey, default ``random.PRNGKey(1234)`` + Random key for the shuffling the weights + + axis : list, tuple, string, default ``None`` + Parameter to set axis properties of ``matplotlib`` figure. For example + it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other + acceptable argument for ``matplotlib.pyplot.axis()`` method. + + save : bool, optionalm, default ``False`` + If ``True``, save the figure with specified filename. + + filename : str + File name and path of the image to save. Depends on the boolean ``save``. + + Returns + ------- + plt : ``matplotlib.pyplot`` object + Reference to plot, call ``show()`` to display it """ nsamples = self.Results.total_num_samples log_p = self.Results.log_dp_mean @@ -648,4 +780,17 @@ def corner_plot(self, param1: str, param2: str, n1=0, n2=0, rkey=random.PRNGKey( density=True, cmap="GnBu", ) - plt.plot() + plt.title("Corner Plot of " + str(param1) + " and " + str(param2)) + plt.xlabel(param2) + plt.ylabel(param1) + plt.colorbar() + if axis is not None: + plt.axis(axis) + + if save: + if filename is None: + plt.savefig(str(param1) + "_" + str(param2) + "_Corner_plot.png") + else: + plt.savefig(filename) + + return plt diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 01ee0a398..992e34045 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -1,12 +1,13 @@ +import os +import numpy as np +import matplotlib.pyplot as plt + import jax import jax.numpy as jnp from jax import random jax.config.update("jax_enable_x64", True) -import numpy as np -import matplotlib.pyplot as plt - from tinygp import GaussianProcess, kernels from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params from stingray.modeling.gpmodeling import get_prior, get_likelihood, GPResult @@ -19,6 +20,12 @@ from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model +def clear_all_figs(): + fign = plt.get_fignums() + for fig in fign: + plt.close(fig) + + class Testget_kernel(object): def setup_class(self): self.x = np.linspace(0, 1, 5) @@ -271,3 +278,83 @@ def test_plot_cornerplot(self): def test_get_parameters_names(self): assert sorted(self.params_list) == self.gpresult.get_parameters_names() + + def test_print_summary(self): + self.gpresult.print_summary() + assert True + + def test_max_posterior_parameters(self): + for key in self.params_list: + assert key in self.gpresult.get_max_posterior_parameters() + + def test_max_likelihood_parameters(self): + for key in self.params_list: + assert key in self.gpresult.get_max_likelihood_parameters() + + def test_posterior_plot(self): + self.gpresult.posterior_plot("A") + assert plt.fignum_exists(1) + + def test_posterior_plot_labels_and_fname_default(self): + clear_all_figs() + outfname = "A_Posterior_plot.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.posterior_plot("A", save=True) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_posterior_plot_labels_and_fname(self): + clear_all_figs() + outfname = "blabla.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.posterior_plot("A", axis=[0, 14, 0, 0.5], save=True, filename=outfname) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_weighted_posterior_plot(self): + self.gpresult.weighted_posterior_plot("A") + assert plt.fignum_exists(1) + + def test_weighted_posterior_plot_labels_and_fname_default(self): + clear_all_figs() + outfname = "A_Weighted_Posterior_plot.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.weighted_posterior_plot("A", save=True) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_weighted_posterior_plot_labels_and_fname(self): + clear_all_figs() + outfname = "blabla.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.weighted_posterior_plot( + "A", axis=[0, 14, 0, 0.5], save=True, filename=outfname + ) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_corner_plot(self): + self.gpresult.corner_plot("A", "t0") + assert plt.fignum_exists(1) + + def test_corner_plot_labels_and_fname_default(self): + clear_all_figs() + outfname = "A_t0_Corner_plot.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.corner_plot("A", "t0", save=True) + assert os.path.exists(outfname) + os.unlink(outfname) + + def test_corner_plot_labels_and_fname(self): + clear_all_figs() + outfname = "blabla.png" + if os.path.exists(outfname): + os.unlink(outfname) + self.gpresult.corner_plot("A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) + assert os.path.exists(outfname) + os.unlink(outfname) From 3bf207924d60efb89429f827b975371488c1300c Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 26 Jul 2023 20:40:31 +0530 Subject: [PATCH 36/50] Added testing skips --- setup.cfg | 12 +- stingray/modeling/gpmodeling.py | 17 +- stingray/modeling/tests/test_gpmodeling.py | 205 ++++++++++++--------- 3 files changed, 134 insertions(+), 100 deletions(-) diff --git a/setup.cfg b/setup.cfg index e15312b2e..8b5adedb1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,12 +22,6 @@ install_requires = scipy>=1.1.0 ; Matplotlib 3.4.0 is incompatible with Astropy matplotlib>=3.0,!=3.4.0 - jax - tinygp - jaxns - etils - tensorflow_probability - typing_extensions [options.entry_points] console_scripts = @@ -50,6 +44,12 @@ all = xarray pandas ultranest + jax + tinygp + jaxns + etils + tensorflow_probability + typing_extensions test = pytest pytest-astropy diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index d2e7f2917..6841cc57b 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import matplotlib.pyplot as plt import functools @@ -6,12 +7,15 @@ try: import jax except ImportError: - raise ImportError("Jax not installed") + pytest.skip(allow_module_level=True) -import jax.numpy as jnp -from jax import jit, random +try: + import jax.numpy as jnp + from jax import jit, random -jax.config.update("jax_enable_x64", True) + jax.config.update("jax_enable_x64", True) +except ImportError: + raise ImportError("Jax not installed") try: from tinygp import GaussianProcess, kernels @@ -376,6 +380,11 @@ def get_prior(params_list, prior_dict): -------- A prior function for a Red Noise kernel and a Gaussian mean function Obain the parameters list + >>> if not can_sample: + ... pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.") + >>> if not tfp_available: + ... pytest.skip("Tensorflow probability required to make priors.") + >>> params_list = get_gp_params("RN", "gaussian") Make a prior dictionary using tensorflow_probability distributions diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 992e34045..74af0b2a0 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -1,23 +1,43 @@ import os +import pytest import numpy as np import matplotlib.pyplot as plt -import jax -import jax.numpy as jnp -from jax import random +try: + import jax + import jax.numpy as jnp + from jax import random -jax.config.update("jax_enable_x64", True) + jax.config.update("jax_enable_x64", True) +except ImportError: + pytest.skip(allow_module_level=True) + +_HAS_TINYGP = True +_HAS_TFP = True +_HAS_JAXNS = True + +try: + import tinygp + from tinygp import GaussianProcess, kernels +except ImportError: + _HAS_TINYGP = False -from tinygp import GaussianProcess, kernels from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params from stingray.modeling.gpmodeling import get_prior, get_likelihood, GPResult from stingray import Lightcurve -import tensorflow_probability.substrates.jax as tfp +try: + import tensorflow_probability.substrates.jax as tfp -tfpd = tfp.distributions + tfpd = tfp.distributions +except ImportError: + _HAS_TFP = False -from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model +try: + import jaxns + from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model +except ImportError: + _HAS_JAXNS = False def clear_all_figs(): @@ -26,6 +46,7 @@ def clear_all_figs(): plt.close(fig) +@pytest.mark.skipif(not _HAS_TINYGP, reason="tinygp not installed") class Testget_kernel(object): def setup_class(self): self.x = np.linspace(0, 1, 5) @@ -126,89 +147,93 @@ def test_get_mean_fred(self): ) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0) assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all() - class Testget_gp_params(object): - def setup_class(self): - pass - - def test_get_gp_params_rn(self): - assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] - assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] - assert get_gp_params("RN", "skew_exponential") == [ - "arn", - "crn", - "A", - "t0", - "sig1", - "sig2", - ] - assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] - - def test_get_gp_params_qpo_plus_rn(self): - assert get_gp_params("QPO_plus_RN", "gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "sig", - ] - assert get_gp_params("QPO_plus_RN", "constant") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - ] - assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "sig1", - "sig2", - ] - assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "sig1", - "sig2", - ] - assert get_gp_params("QPO_plus_RN", "exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "sig", - ] - assert get_gp_params("QPO_plus_RN", "fred") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", - "t0", - "delta", - "phi", - ] - +class Testget_gp_params(object): + def setup_class(self): + pass + + def test_get_gp_params_rn(self): + assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] + assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] + assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] + assert get_gp_params("RN", "skew_exponential") == [ + "arn", + "crn", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] + assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] + + def test_get_gp_params_qpo_plus_rn(self): + assert get_gp_params("QPO_plus_RN", "gaussian") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig", + ] + assert get_gp_params("QPO_plus_RN", "constant") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + ] + assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig1", + "sig2", + ] + assert get_gp_params("QPO_plus_RN", "exponential") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "sig", + ] + assert get_gp_params("QPO_plus_RN", "fred") == [ + "arn", + "crn", + "aqpo", + "cqpo", + "freq", + "A", + "t0", + "delta", + "phi", + ] + + +@pytest.mark.skipif( + not (_HAS_TINYGP and _HAS_TFP and _HAS_JAXNS), reason="tinygp, tfp or jaxns not installed" +) class TestGPResult(object): def setup_class(self): self.Times = np.linspace(0, 1, 64) From a14ca4a51d4a657848563d0d8d5ab0092d683787 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Mon, 14 Aug 2023 15:38:39 +0530 Subject: [PATCH 37/50] Jax import changed --- stingray/modeling/gpmodeling.py | 63 +++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 6841cc57b..40977280e 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -1,4 +1,3 @@ -import pytest import numpy as np import matplotlib.pyplot as plt import functools @@ -6,16 +5,13 @@ try: import jax -except ImportError: - pytest.skip(allow_module_level=True) - -try: import jax.numpy as jnp from jax import jit, random jax.config.update("jax_enable_x64", True) + jax_avail = True except ImportError: - raise ImportError("Jax not installed") + jax_avail = False try: from tinygp import GaussianProcess, kernels @@ -60,6 +56,9 @@ def get_kernel(kernel_type, kernel_params): Should contain the parameters for the selected kernel """ + if not jax_avail: + raise ImportError("Jax is required") + if not can_make_gp: raise ImportError("Tinygp is required to make kernels") @@ -95,6 +94,9 @@ def get_mean(mean_type, mean_params): Should contain the parameters for the selected mean """ + if not jax_avail: + raise ImportError("Jax is required") + if mean_type == "gaussian": mean = functools.partial(_gaussian, mean_params=mean_params) elif mean_type == "exponential": @@ -340,7 +342,7 @@ def get_gp_params(kernel_type, mean_type): Examples -------- - >>> get_gp_params("QPO_plus_RN", "gaussian") + get_gp_params("QPO_plus_RN", "gaussian") ['arn', 'crn', 'aqpo', 'cqpo', 'freq', 'A', 't0', 'sig'] """ kernel_params = _get_kernel_params(kernel_type) @@ -380,25 +382,28 @@ def get_prior(params_list, prior_dict): -------- A prior function for a Red Noise kernel and a Gaussian mean function Obain the parameters list - >>> if not can_sample: - ... pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.") - >>> if not tfp_available: - ... pytest.skip("Tensorflow probability required to make priors.") + if not can_sample: + pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.") + if not tfp_available: + pytest.skip("Tensorflow probability required to make priors.") - >>> params_list = get_gp_params("RN", "gaussian") + params_list = get_gp_params("RN", "gaussian") Make a prior dictionary using tensorflow_probability distributions - >>> prior_dict = { - ... "A": tfpd.Uniform(low = 1e-1, high = 2e+2), - ... "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), - ... "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), - ... "arn": tfpd.Uniform(low = 0.1 , high = 2 ), - ... "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), - ... } + prior_dict = { + "A": tfpd.Uniform(low = 1e-1, high = 2e+2), + "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), + "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), + "arn": tfpd.Uniform(low = 0.1 , high = 2 ), + "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), + } - >>> prior_model = get_prior(params_list, prior_dict) + prior_model = get_prior(params_list, prior_dict) """ + if not jax_avail: + raise ImportError("Jax is required") + if not can_sample: raise ImportError("Jaxns not installed. Cannot make jaxns specific prior.") @@ -454,6 +459,9 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): The jaxns specific log likelihood function. """ + if not jax_avail: + raise ImportError("Jax is required") + if not can_make_gp: raise ImportError("Tinygp is required to make the GP model.") @@ -516,6 +524,9 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): The results of the nested sampling process """ + if not jax_avail: + raise ImportError("Jax is required") + if not can_sample: raise ImportError("Jaxns not installed! Can't sample!") @@ -600,7 +611,7 @@ def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other acceptable argument for ``matplotlib.pyplot.axis()`` method. - save : bool, optionalm, default ``False`` + save : bool, optional, default ``False`` If ``True``, save the figure with specified filename. filename : str @@ -638,7 +649,7 @@ def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): return plt def weighted_posterior_plot( - self, name: str, n=0, rkey=random.PRNGKey(1234), axis=None, save=False, filename=None + self, name: str, n=0, rkey=None, axis=None, save=False, filename=None ): """ Returns the weighted posterior histogram for the given parameter @@ -673,6 +684,9 @@ def weighted_posterior_plot( plt : ``matplotlib.pyplot`` object Reference to plot, call ``show()`` to display it """ + if rkey is None: + rkey = random.PRNGKey(1234) + nsamples = self.Results.total_num_samples log_p = self.Results.log_dp_mean samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] @@ -720,7 +734,7 @@ def corner_plot( param2: str, n1=0, n2=0, - rkey=random.PRNGKey(1234), + rkey=None, axis=None, save=False, filename=None, @@ -767,6 +781,9 @@ def corner_plot( plt : ``matplotlib.pyplot`` object Reference to plot, call ``show()`` to display it """ + if rkey is None: + rkey = random.PRNGKey(1234) + nsamples = self.Results.total_num_samples log_p = self.Results.log_dp_mean samples1 = self.Results.samples[param1].reshape((nsamples, -1))[:, n1] From f3bef3f86ad3b41635cf0354c154fb952c8dc302 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Thu, 17 Aug 2023 14:01:42 +0530 Subject: [PATCH 38/50] added kernel and warnings --- stingray/modeling/gpmodeling.py | 12 ++++++++++++ stingray/modeling/tests/test_gpmodeling.py | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 40977280e..051cf0b5c 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -77,6 +77,16 @@ def get_kernel(kernel_type, kernel_params): scale=1 / kernel_params["crn"], sigma=(kernel_params["arn"]) ** 0.5 ) return kernel + elif kernel_type == "QPO": + kernel = kernels.quasisep.Celerite( + a=kernel_params["aqpo"], + b=0.0, + c=kernel_params["cqpo"], + d=2 * jnp.pi * kernel_params["freq"], + ) + return kernel + else: + raise ValueError("Kernel type not implemented") def get_mean(mean_type, mean_params): @@ -109,6 +119,8 @@ def get_mean(mean_type, mean_params): mean = functools.partial(_skew_exponential, mean_params=mean_params) elif mean_type == "fred": mean = functools.partial(_fred, mean_params=mean_params) + else: + raise ValueError("Mean type not implemented") return mean diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 74af0b2a0..a5a0dbd64 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -74,6 +74,22 @@ def test_get_kernel_rn(self): kernel_rn(self.x, jnp.array([0.0])) == kernel_rn_test(self.x, jnp.array([0.0])) ).all() + def test_get_kernel_qpo(self): + kernel_qpo = kernels.quasisep.Celerite( + a=1, + b=0.0, + c=1, + d=2 * jnp.pi * 1, + ) + kernel_qpo_test = get_kernel("QPO", self.kernel_params) + assert ( + kernel_qpo(self.x, jnp.array([0.0])) == kernel_qpo_test(self.x, jnp.array([0.0])) + ).all() + + def test_value_error(self): + with pytest.raises(ValueError, match="Kernel type not implemented"): + get_kernel("periodic", self.kernel_params) + class Testget_mean(object): def setup_class(self): @@ -147,6 +163,10 @@ def test_get_mean_fred(self): ) + 4.0 * jnp.exp(-5.0 * ((self.t + 0.4) / 0.7 + 0.7 / (self.t + 0.4))) * jnp.exp(2 * 5.0) assert (get_mean("fred", self.fred_mean_params)(self.t) == result_fred).all() + def test_value_error(self): + with pytest.raises(ValueError, match="Mean type not implemented"): + get_mean("polynomial", self.mean_params) + class Testget_gp_params(object): def setup_class(self): From 181e90a67e252ee318c1eb08200af240370e5c1a Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Sun, 20 Aug 2023 14:11:44 +0530 Subject: [PATCH 39/50] Adding max samples option --- stingray/modeling/gpmodeling.py | 7 +++++-- stingray/modeling/tests/test_gpmodeling.py | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 051cf0b5c..f85f13f3f 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -516,7 +516,7 @@ def __init__(self, Lc: Lightcurve) -> None: self.counts = Lc.counts self.Result = None - def sample(self, prior_model=None, likelihood_model=None, **kwargs): + def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): """ Makes a Jaxns nested sampler over the Gaussian Process, given the prior and likelihood model @@ -530,6 +530,9 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): A likelihood fucntion which takes in the arguments of the prior model and returns the loglikelihood of the model + max_samples: int, default 1e4 + The maximum number of samples to be taken by the nested sampler + Returns ---------- Results: jaxns.results.NestedSamplerResults object @@ -548,7 +551,7 @@ def sample(self, prior_model=None, likelihood_model=None, **kwargs): NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) - self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) + self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) Termination_reason, State = self.Exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index a5a0dbd64..e65dff3f8 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -297,14 +297,16 @@ def setup_class(self): NSmodel = Model(prior_model=prior_model, log_likelihood=likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) - Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=1e4) + Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=5e3) Termination_reason, State = Exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) self.Results = Exact_ns.to_results(State, Termination_reason) self.gpresult = GPResult(lc) - self.gpresult.sample(prior_model=prior_model, likelihood_model=likelihood_model) + self.gpresult.sample( + prior_model=prior_model, likelihood_model=likelihood_model, max_samples=5e3 + ) def test_sample(self): for key in self.params_list: From b374d6bc27809a146b1c8e3ac4aa6deb99089fc8 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Sun, 20 Aug 2023 17:54:38 +0530 Subject: [PATCH 40/50] Added log parameters --- stingray/modeling/gpmodeling.py | 39 ++++-- stingray/modeling/tests/test_gpmodeling.py | 144 +++++++++------------ 2 files changed, 90 insertions(+), 93 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index f85f13f3f..134ddf145 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -301,15 +301,20 @@ def _get_kernel_params(kernel_type): ---------- kernel_type: string The type of kernel to be used for the Gaussian Process model + The parameters in log scale have a prefix of "log_" Returns ------- A list of the parameters for the kernel for the GP model """ if kernel_type == "RN": - return ["arn", "crn"] + return ["log_arn", "log_crn"] elif kernel_type == "QPO_plus_RN": - return ["arn", "crn", "aqpo", "cqpo", "freq"] + return ["log_arn", "log_crn", "log_aqpo", "log_cqpo", "log_freq"] + elif kernel_type == "QPO": + return ["log_aqpo", "log_cqpo", "log_freq"] + else: + raise ValueError("Kernel type not implemented") def _get_mean_params(mean_type): @@ -320,19 +325,22 @@ def _get_mean_params(mean_type): ---------- mean_type: string The type of mean to be used for the Gaussian Process model + The parameters in log scale have a prefix of "log_" Returns ------- A list of the parameters for the mean for the GP model """ if (mean_type == "gaussian") or (mean_type == "exponential"): - return ["A", "t0", "sig"] + return ["log_A", "t0", "log_sig"] elif mean_type == "constant": - return ["A"] + return ["log_A"] elif (mean_type == "skew_gaussian") or (mean_type == "skew_exponential"): - return ["A", "t0", "sig1", "sig2"] + return ["log_A", "t0", "log_sig1", "log_sig2"] elif mean_type == "fred": - return ["A", "t0", "delta", "phi"] + return ["log_A", "t0", "delta", "phi"] + else: + raise ValueError("Mean type not implemented") def get_gp_params(kernel_type, mean_type): @@ -355,7 +363,7 @@ def get_gp_params(kernel_type, mean_type): Examples -------- get_gp_params("QPO_plus_RN", "gaussian") - ['arn', 'crn', 'aqpo', 'cqpo', 'freq', 'A', 't0', 'sig'] + ['log_arn', 'log_crn', 'log_aqpo', 'log_cqpo', 'log_freq', 'log_A', 't0', 'log_sig'] """ kernel_params = _get_kernel_params(kernel_type) mean_params = _get_mean_params(mean_type) @@ -381,6 +389,7 @@ def get_prior(params_list, prior_dict): or special priors from jaxns. **Note**: If jaxns priors are used, then the name given to them should be the same as the corresponding name in the params_list. + Also, if a parameter is to be used in the log scale, it should have a prefix of "log_" Returns ------- @@ -403,11 +412,11 @@ def get_prior(params_list, prior_dict): Make a prior dictionary using tensorflow_probability distributions prior_dict = { - "A": tfpd.Uniform(low = 1e-1, high = 2e+2), + "log_A": tfpd.Uniform(low = jnp.log(1e-1), high = jnp.log(2e+2)), "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), - "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), - "arn": tfpd.Uniform(low = 0.1 , high = 2 ), - "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), + "log_sig": tfpd.Uniform(low = jnp.log(0.5 * 1 / 20), high = jnp.log(2) ), + "log_arn": tfpd.Uniform(low = jnp.log(0.1) , high = jnp.log(2) ), + "log_crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), } prior_model = get_prior(params_list, prior_dict) @@ -441,7 +450,8 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): Makes a jaxns specific log likelihood function which takes in the parameters in the order of the parameters list, and calculates the log likelihood of the data given the parameters, and the model - (kernel, mean) of the GP model. + (kernel, mean) of the GP model. **Note** Any parameters with a prefix + of "log_" are taken to be in the log scale. Parameters ---------- @@ -481,7 +491,10 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): def likelihood_model(*args): dict = {} for i, params in enumerate(params_list): - dict[params] = args[i] + if params[0:4] == "log_": + dict[params[4:]] = jnp.exp(args[i]) + else: + dict[params] = args[i] kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict) mean = get_mean(mean_type=mean_type, mean_params=dict) gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index e65dff3f8..1d7f5eecd 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -173,81 +173,65 @@ def setup_class(self): pass def test_get_gp_params_rn(self): - assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] - assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] - assert get_gp_params("RN", "skew_exponential") == [ - "arn", - "crn", - "A", + assert get_gp_params("RN", "gaussian") == ["log_arn", "log_crn", "log_A", "t0", "log_sig"] + assert get_gp_params("RN", "constant") == ["log_arn", "log_crn", "log_A"] + assert get_gp_params("RN", "skew_gaussian") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "log_sig1", + "log_sig2", ] - assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] - - def test_get_gp_params_qpo_plus_rn(self): - assert get_gp_params("QPO_plus_RN", "gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "skew_exponential") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig", - ] - assert get_gp_params("QPO_plus_RN", "constant") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + "log_sig1", + "log_sig2", ] - assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "exponential") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "log_sig", ] - assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "fred") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "delta", + "phi", ] - assert get_gp_params("QPO_plus_RN", "exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + + def test_get_gp_params_qpo_plus_rn(self): + assert get_gp_params("QPO_plus_RN", "gaussian") == [ + "log_arn", + "log_crn", + "log_aqpo", + "log_cqpo", + "log_freq", + "log_A", "t0", - "sig", + "log_sig", ] - assert get_gp_params("QPO_plus_RN", "fred") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + with pytest.raises(ValueError, match="Mean type not implemented"): + get_gp_params("QPO_plus_RN", "notimplemented") + + with pytest.raises(ValueError, match="Kernel type not implemented"): + get_gp_params("notimplemented", "gaussian") + + def test_get_qpo(self): + assert get_gp_params("QPO", "gaussian") == [ + "log_aqpo", + "log_cqpo", + "log_freq", + "log_A", "t0", - "delta", - "phi", + "log_sig", ] @@ -278,11 +262,11 @@ def setup_class(self): # The prior dictionary, with suitable tfpd prior distributions prior_dict = { - "A": tfpd.Uniform(low=0.1 * span, high=2 * span), + "log_A": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), "t0": tfpd.Uniform(low=self.Times[0] - 0.1 * T, high=self.Times[-1] + 0.1 * T), - "sig": tfpd.Uniform(low=0.5 * 1 / f, high=2 * T), - "arn": tfpd.Uniform(low=0.1 * span, high=2 * span), - "crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)), + "log_sig": tfpd.Uniform(low=jnp.log(0.5 * 1 / f), high=jnp.log(2 * T)), + "log_arn": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), + "log_crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)), } prior_model = get_prior(self.params_list, prior_dict) @@ -339,15 +323,15 @@ def test_max_likelihood_parameters(self): assert key in self.gpresult.get_max_likelihood_parameters() def test_posterior_plot(self): - self.gpresult.posterior_plot("A") + self.gpresult.posterior_plot("log_A") assert plt.fignum_exists(1) def test_posterior_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_Posterior_plot.png" + outfname = "log_A_Posterior_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.posterior_plot("A", save=True) + self.gpresult.posterior_plot("log_A", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -356,20 +340,20 @@ def test_posterior_plot_labels_and_fname(self): outfname = "blabla.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.posterior_plot("A", axis=[0, 14, 0, 0.5], save=True, filename=outfname) + self.gpresult.posterior_plot("log_A", axis=[0, 14, 0, 0.5], save=True, filename=outfname) assert os.path.exists(outfname) os.unlink(outfname) def test_weighted_posterior_plot(self): - self.gpresult.weighted_posterior_plot("A") + self.gpresult.weighted_posterior_plot("log_A") assert plt.fignum_exists(1) def test_weighted_posterior_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_Weighted_Posterior_plot.png" + outfname = "log_A_Weighted_Posterior_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.weighted_posterior_plot("A", save=True) + self.gpresult.weighted_posterior_plot("log_A", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -379,21 +363,21 @@ def test_weighted_posterior_plot_labels_and_fname(self): if os.path.exists(outfname): os.unlink(outfname) self.gpresult.weighted_posterior_plot( - "A", axis=[0, 14, 0, 0.5], save=True, filename=outfname + "log_A", axis=[0, 14, 0, 0.5], save=True, filename=outfname ) assert os.path.exists(outfname) os.unlink(outfname) def test_corner_plot(self): - self.gpresult.corner_plot("A", "t0") + self.gpresult.corner_plot("log_A", "t0") assert plt.fignum_exists(1) def test_corner_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_t0_Corner_plot.png" + outfname = "log_A_t0_Corner_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("A", "t0", save=True) + self.gpresult.corner_plot("log_A", "t0", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -402,6 +386,6 @@ def test_corner_plot_labels_and_fname(self): outfname = "blabla.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) + self.gpresult.corner_plot("log_A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) assert os.path.exists(outfname) os.unlink(outfname) From cb0a0f555c210eebfa43254f0bca417a01beab2b Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Fri, 25 Aug 2023 21:27:31 +0530 Subject: [PATCH 41/50] Changed Function Names --- stingray/modeling/gpmodeling.py | 14 +++++++------- stingray/modeling/tests/test_gpmodeling.py | 20 +++++++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 134ddf145..b3e16cacb 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -444,7 +444,7 @@ def prior_model(): return prior_model -def get_likelihood(params_list, kernel_type, mean_type, **kwargs): +def get_log_likelihood(params_list, kernel_type, mean_type, **kwargs): """ A log likelihood generator function based on given values. Makes a jaxns specific log likelihood function which takes in the @@ -559,9 +559,9 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): raise ImportError("Jaxns not installed! Can't sample!") self.prior_model = prior_model - self.likelihood_model = likelihood_model + self.log_likelihood_model = likelihood_model - NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.likelihood_model) + NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) @@ -756,7 +756,7 @@ def weighted_posterior_plot( plt.savefig(filename) return plt - def corner_plot( + def comparison_plot( self, param1: str, param2: str, @@ -768,7 +768,7 @@ def corner_plot( filename=None, ): """ - Plots the corner plot between two given parameters + Plots the comparison plot between two given parameters Parameters ---------- @@ -834,7 +834,7 @@ def corner_plot( density=True, cmap="GnBu", ) - plt.title("Corner Plot of " + str(param1) + " and " + str(param2)) + plt.title("Comparison Plot of " + str(param1) + " and " + str(param2)) plt.xlabel(param2) plt.ylabel(param1) plt.colorbar() @@ -843,7 +843,7 @@ def corner_plot( if save: if filename is None: - plt.savefig(str(param1) + "_" + str(param2) + "_Corner_plot.png") + plt.savefig(str(param1) + "_" + str(param2) + "_Comparison_plot.png") else: plt.savefig(filename) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 1d7f5eecd..02743af34 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -23,7 +23,7 @@ _HAS_TINYGP = False from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params -from stingray.modeling.gpmodeling import get_prior, get_likelihood, GPResult +from stingray.modeling.gpmodeling import get_prior, get_log_likelihood, GPResult from stingray import Lightcurve try: @@ -270,7 +270,7 @@ def setup_class(self): } prior_model = get_prior(self.params_list, prior_dict) - likelihood_model = get_likelihood( + likelihood_model = get_log_likelihood( self.params_list, kernel_type="RN", mean_type="gaussian", @@ -368,24 +368,26 @@ def test_weighted_posterior_plot_labels_and_fname(self): assert os.path.exists(outfname) os.unlink(outfname) - def test_corner_plot(self): - self.gpresult.corner_plot("log_A", "t0") + def test_comparison_plot(self): + self.gpresult.comparison_plot("log_A", "t0") assert plt.fignum_exists(1) - def test_corner_plot_labels_and_fname_default(self): + def test_comparison_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "log_A_t0_Corner_plot.png" + outfname = "log_A_t0_Comparison_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("log_A", "t0", save=True) + self.gpresult.comparison_plot("log_A", "t0", save=True) assert os.path.exists(outfname) os.unlink(outfname) - def test_corner_plot_labels_and_fname(self): + def test_comparison_plot_labels_and_fname(self): clear_all_figs() outfname = "blabla.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("log_A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) + self.gpresult.comparison_plot( + "log_A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname + ) assert os.path.exists(outfname) os.unlink(outfname) From a2c03c0148dc26403de92600fd66c85724338e0a Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Fri, 25 Aug 2023 23:05:03 +0530 Subject: [PATCH 42/50] Docstring changes --- stingray/modeling/gpmodeling.py | 147 +++++++++++---------- stingray/modeling/tests/test_gpmodeling.py | 4 +- 2 files changed, 76 insertions(+), 75 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index b3e16cacb..e52a71777 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -49,7 +49,8 @@ def get_kernel(kernel_type, kernel_params): ---------- kernel_type: string The type of kernel to be used for the Gaussian Process - To be selected from the kernels already implemented + To be selected from the kernels already implemented: + ["RN", "QPO", "QPO_plus_RN"] kernel_params: dict Dictionary containing the parameters for the kernel @@ -97,7 +98,9 @@ def get_mean(mean_type, mean_params): ---------- mean_type: string The type of mean to be used for the Gaussian Process - To be selected from the mean functions already implemented + To be selected from the mean functions already implemented: + ["gaussian", "exponential", "constant", "skew_gaussian", + "skew_exponential", "fred"] mean_params: dict Dictionary containing the parameters for the mean @@ -216,7 +219,7 @@ def _skew_gaussian(t, mean_params): sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] - return jnp.sum( + y = jnp.sum( A * jnp.where( t > t0, @@ -225,6 +228,7 @@ def _skew_gaussian(t, mean_params): ), axis=0, ) + return y def _skew_exponential(t, mean_params): @@ -252,7 +256,7 @@ def _skew_exponential(t, mean_params): sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] - return jnp.sum( + y = jnp.sum( A * jnp.where( t > t0, @@ -261,6 +265,7 @@ def _skew_exponential(t, mean_params): ), axis=0, ) + return y def _fred(t, mean_params): @@ -351,10 +356,13 @@ def get_gp_params(kernel_type, mean_type): Parameters ---------- kernel_type: string - The type of kernel to be used for the Gaussian Process model + The type of kernel to be used for the Gaussian Process model: + ["RN", "QPO", "QPO_plus_RN"] mean_type: string - The type of mean to be used for the Gaussian Process model + The type of mean to be used for the Gaussian Process model: + ["gaussian", "exponential", "constant", "skew_gaussian", + "skew_exponential", "fred"] Returns ------- @@ -402,15 +410,11 @@ def get_prior(params_list, prior_dict): Examples -------- A prior function for a Red Noise kernel and a Gaussian mean function - Obain the parameters list - if not can_sample: - pytest.skip("Jaxns not installed. Cannot make jaxns specific prior.") - if not tfp_available: - pytest.skip("Tensorflow probability required to make priors.") + # Obtain the parameters list params_list = get_gp_params("RN", "gaussian") - Make a prior dictionary using tensorflow_probability distributions + # Make a prior dictionary using tensorflow_probability distributions prior_dict = { "log_A": tfpd.Uniform(low = jnp.log(1e-1), high = jnp.log(2e+2)), "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), @@ -444,7 +448,7 @@ def prior_model(): return prior_model -def get_log_likelihood(params_list, kernel_type, mean_type, **kwargs): +def get_log_likelihood(params_list, kernel_type, mean_type, times, counts, **kwargs): """ A log likelihood generator function based on given values. Makes a jaxns specific log likelihood function which takes in the @@ -462,23 +466,23 @@ def get_log_likelihood(params_list, kernel_type, mean_type, **kwargs): A dictionary of the priors of parameters to be used. kernel_type: - The type of kernel to be used in the model. + The type of kernel to be used in the model: + ["RN", "QPO", "QPO_plus_RN"] mean_type: - The type of mean to be used in the model. + The type of mean to be used in the model: + ["gaussian", "exponential", "constant", "skew_gaussian", + "skew_exponential", "fred"] + + times: np.array or jnp.array + The time array of the lightcurve - **kwargs: - The keyword arguments to be used in the log likelihood function. - **Note**: The keyword arguments Times and counts are necessary for - calculating the log likelihood. - Times: np.array or jnp.array - The time array of the lightcurve - counts: np.array or jnp.array - The photon counts array of the lightcurve + counts: np.array or jnp.array + The photon counts array of the lightcurve Returns ------- - The jaxns specific log likelihood function. + The Jaxns specific log likelihood function. """ if not jax_avail: @@ -497,8 +501,8 @@ def likelihood_model(*args): dict[params] = args[i] kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict) mean = get_mean(mean_type=mean_type, mean_params=dict) - gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) - return gp.log_probability(kwargs["counts"]) + gp = GaussianProcess(kernel, times, mean_value=mean(times)) + return gp.log_probability(counts) return likelihood_model @@ -513,21 +517,13 @@ class GPResult: lc: Stingray.Lightcurve object The lightcurve on which the bayesian inference is to be done - Other Parameters - ---------------- - time : class: np.array - The array containing the times of the lightcurve - - counts : class: np.array - The array containing the photon counts of the lightcurve - """ - def __init__(self, Lc: Lightcurve) -> None: - self.lc = Lc - self.time = Lc.time - self.counts = Lc.counts - self.Result = None + def __init__(self, lc: Lightcurve) -> None: + self.lc = lc + self.time = lc.time + self.counts = lc.counts + self.result = None def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): """ @@ -537,18 +533,23 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): Parameters ---------- prior_model: jaxns.prior.PriorModelType object - A prior generator object + A prior generator object. + Can be made using the get_prior function or can use your own jaxns + compatible prior function. likelihood_model: jaxns.types.LikelihoodType object A likelihood fucntion which takes in the arguments of the prior - model and returns the loglikelihood of the model + model and returns the loglikelihood of the model. + Can be made using the get_log_likelihood function or can use your own + log_likelihood function with same order of arguments as the prior_model. + max_samples: int, default 1e4 The maximum number of samples to be taken by the nested sampler Returns ---------- - Results: jaxns.results.NestedSamplerResults object + results: jaxns.results.NestedSamplerResults object The results of the nested sampling process """ @@ -564,49 +565,49 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) NSmodel.sanity_check(random.PRNGKey(10), S=100) - self.Exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) - Termination_reason, State = self.Exact_ns( + self.exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) + termination_reason, State = self.exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) - self.Results = self.Exact_ns.to_results(State, Termination_reason) + self.results = self.exact_ns.to_results(State, termination_reason) print("Simulation Complete") def get_evidence(self): """ Returns the log evidence of the model """ - return self.Results.log_Z_mean + return self.results.log_Z_mean def print_summary(self): """ Prints a summary table for the model parameters """ - self.Exact_ns.summary(self.Results) + self.exact_ns.summary(self.results) def plot_diagnostics(self): """ Plots the diagnostic plots for the sampling process """ - self.Exact_ns.plot_diagnostics(self.Results) + self.exact_ns.plot_diagnostics(self.results) def plot_cornerplot(self): """ Plots the corner plot for the sampled hyperparameters """ - self.Exact_ns.plot_cornerplot(self.Results) + self.exact_ns.plot_cornerplot(self.results) def get_parameters_names(self): """ Returns the names of the parameters """ - return sorted(self.Results.samples.keys()) + return sorted(self.results.samples.keys()) def get_max_posterior_parameters(self): """ Returns the optimal parameters for the model based on the NUTS sampling """ - max_post_idx = jnp.argmax(self.Results.log_posterior_density) - map_points = jax.tree_map(lambda x: x[max_post_idx], self.Results.samples) + max_post_idx = jnp.argmax(self.results.log_posterior_density) + map_points = jax.tree_map(lambda x: x[max_post_idx], self.results.samples) return map_points @@ -614,8 +615,8 @@ def get_max_likelihood_parameters(self): """ Retruns the maximum likelihood parameters """ - max_like_idx = jnp.argmax(self.Results.log_L_samples) - max_like_points = jax.tree_map(lambda x: x[max_like_idx], self.Results.samples) + max_like_idx = jnp.argmax(self.results.log_L_samples) + max_like_points = jax.tree_map(lambda x: x[max_like_idx], self.results.samples) return max_like_points @@ -631,7 +632,7 @@ def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): used in the prior_function n : int, default 0 - The index of the parameter to be plotted. + The index of the parameter to be plotted (for multi component parameters). For multivariate parameters, the index of the specific parameter to be plotted. axis : list, tuple, string, default ``None`` @@ -651,13 +652,13 @@ def posterior_plot(self, name: str, n=0, axis=None, save=False, filename=None): Reference to plot, call ``show()`` to display it """ - nsamples = self.Results.total_num_samples - samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] + nsamples = self.results.total_num_samples + samples = self.results.samples[name].reshape((nsamples, -1))[:, n] plt.hist( samples, bins="auto", density=True, alpha=1.0, label=name, fc="None", edgecolor="black" ) - mean1 = jnp.mean(self.Results.samples[name]) - std1 = jnp.std(self.Results.samples[name]) + mean1 = jnp.mean(self.results.samples[name]) + std1 = jnp.std(self.results.samples[name]) plt.axvline(mean1, color="red", linestyle="dashed", label="mean") plt.axvline(mean1 + std1, color="green", linestyle="dotted") plt.axvline(mean1 - std1, linestyle="dotted", color="green") @@ -690,7 +691,7 @@ def weighted_posterior_plot( used in the prior_function n : int, default 0 - The index of the parameter to be plotted. + The index of the parameter to be plotted (for multi component parameters). For multivariate parameters, the index of the specific parameter to be plotted. key: jax.random.PRNGKey, default ``random.PRNGKey(1234)`` @@ -715,17 +716,17 @@ def weighted_posterior_plot( if rkey is None: rkey = random.PRNGKey(1234) - nsamples = self.Results.total_num_samples - log_p = self.Results.log_dp_mean - samples = self.Results.samples[name].reshape((nsamples, -1))[:, n] + nsamples = self.results.total_num_samples + log_p = self.results.log_dp_mean + samples = self.results.samples[name].reshape((nsamples, -1))[:, n] weights = jnp.where(jnp.isfinite(samples), jnp.exp(log_p), 0.0) log_weights = jnp.where(jnp.isfinite(samples), log_p, -jnp.inf) samples_resampled = resample( - rkey, samples, log_weights, S=max(10, int(self.Results.ESS)), replace=True + rkey, samples, log_weights, S=max(10, int(self.results.ESS)), replace=True ) - nbins = max(10, int(jnp.sqrt(self.Results.ESS)) + 1) + nbins = max(10, int(jnp.sqrt(self.results.ESS)) + 1) binsx = jnp.linspace(*jnp.percentile(samples_resampled, jnp.asarray([0, 100])), 2 * nbins) plt.hist( @@ -783,11 +784,11 @@ def comparison_plot( used in the prior_function n1 : int, default 0 - The index of the first parameter to be plotted. + The index of the first parameter to be plotted (for multi component parameters). For multivariate parameters, the index of the specific parameter to be plotted. n2 : int, default 0 - The index of the second parameter to be plotted. + The index of the second parameter to be plotted (for multi component parameters). For multivariate parameters, the index of the specific parameter to be plotted. key: jax.random.PRNGKey, default ``random.PRNGKey(1234)`` @@ -812,19 +813,19 @@ def comparison_plot( if rkey is None: rkey = random.PRNGKey(1234) - nsamples = self.Results.total_num_samples - log_p = self.Results.log_dp_mean - samples1 = self.Results.samples[param1].reshape((nsamples, -1))[:, n1] - samples2 = self.Results.samples[param2].reshape((nsamples, -1))[:, n2] + nsamples = self.results.total_num_samples + log_p = self.results.log_dp_mean + samples1 = self.results.samples[param1].reshape((nsamples, -1))[:, n1] + samples2 = self.results.samples[param2].reshape((nsamples, -1))[:, n2] log_weights = jnp.where(jnp.isfinite(samples2), log_p, -jnp.inf) - nbins = max(10, int(jnp.sqrt(self.Results.ESS)) + 1) + nbins = max(10, int(jnp.sqrt(self.results.ESS)) + 1) samples_resampled = resample( rkey, jnp.stack([samples1, samples2], axis=-1), log_weights, - S=max(10, int(self.Results.ESS)), + S=max(10, int(self.results.ESS)), replace=True, ) plt.hist2d( diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 02743af34..722467ce3 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -274,7 +274,7 @@ def setup_class(self): self.params_list, kernel_type="RN", mean_type="gaussian", - Times=self.Times, + times=self.Times, counts=self.counts, ) @@ -294,7 +294,7 @@ def setup_class(self): def test_sample(self): for key in self.params_list: - assert (self.Results.samples[key]).all() == (self.gpresult.Results.samples[key]).all() + assert (self.Results.samples[key]).all() == (self.gpresult.results.samples[key]).all() def test_get_evidence(self): assert self.Results.log_Z_mean == self.gpresult.get_evidence() From eeb241afd7bb3b0e02266781df61652016cf6657 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Mon, 28 Aug 2023 21:41:12 +0530 Subject: [PATCH 43/50] Changelog Changed --- docs/changes/739.feature.rst | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/docs/changes/739.feature.rst b/docs/changes/739.feature.rst index b1661b37c..ff7cecff8 100644 --- a/docs/changes/739.feature.rst +++ b/docs/changes/739.feature.rst @@ -1 +1,19 @@ -A feature dealing with Gaussian Processes for Qpo analysis \ No newline at end of file +This is a JAX implementation of the GP tool by `Hubener et al `_ +for QPO detection and parameter analysis. + +This feature makes use of tinygp library for Gaussian Processes implementation, and jaxns for nested sampling, +and is kept in the stingray.modeling.gpmodelling module. + +Main features of the module are: + +- get_prior: This function makes the prior function for a specified prior dictionary. +- get_likelihood: This function makes the log_likelihood function for the given Kernel, Mean model and lightcurve. +- GPResult class: The class which takes a Lightcurve, and performs Nested Sampling for a given prior and likelihood. + +The additional Dependencies for the code +- jax +- tinygp +- jaxns +- etils +- tensorflow_probability +- typing_extensions \ No newline at end of file From c73620b4b778832af7ade1fc333b3a4c85a5eda1 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 29 Aug 2023 01:26:57 +0530 Subject: [PATCH 44/50] Improved get_mean docs --- stingray/modeling/gpmodeling.py | 185 ++++++++++++++++++++------------ 1 file changed, 115 insertions(+), 70 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index e52a71777..8780e6af6 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -92,7 +92,7 @@ def get_kernel(kernel_type, kernel_params): def get_mean(mean_type, mean_params): """ - Function for producing the mean for the Gaussian Process. + Function for producing the mean function for the Gaussian Process. Parameters ---------- @@ -106,118 +106,153 @@ def get_mean(mean_type, mean_params): Dictionary containing the parameters for the mean Should contain the parameters for the selected mean + Returns + ------- + A function which takes in the time coordinates and returns the mean values. + + Examples + -------- + Unimodal Gaussian Mean Function: + mean_params = {"A": 3.0, "t0": 0.2, "sig1": 0.1, "sig2": 0.4} + mean = get_mean("gaussian", mean_params) + + Multimodal Gaussian Mean Function: + mean_params = {"A": jnp.array([3.0, 4.0]), "t0": jnp.array([0.2, 1]), + "sig1": jnp.array([0.1, 0.4]), "sig2": jnp.array([0.4, 0.1])} + mean = get_mean("gaussian", mean_params) + """ if not jax_avail: raise ImportError("Jax is required") if mean_type == "gaussian": - mean = functools.partial(_gaussian, mean_params=mean_params) + mean = functools.partial(_gaussian, params=mean_params) elif mean_type == "exponential": - mean = functools.partial(_exponential, mean_params=mean_params) + mean = functools.partial(_exponential, params=mean_params) elif mean_type == "constant": - mean = functools.partial(_constant, mean_params=mean_params) + mean = functools.partial(_constant, params=mean_params) elif mean_type == "skew_gaussian": - mean = functools.partial(_skew_gaussian, mean_params=mean_params) + mean = functools.partial(_skew_gaussian, params=mean_params) elif mean_type == "skew_exponential": - mean = functools.partial(_skew_exponential, mean_params=mean_params) + mean = functools.partial(_skew_exponential, params=mean_params) elif mean_type == "fred": - mean = functools.partial(_fred, mean_params=mean_params) + mean = functools.partial(_fred, params=mean_params) else: raise ValueError("Mean type not implemented") return mean -def _gaussian(t, mean_params): +def _gaussian(t, params): """A gaussian flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the gaussian. + + params: dict + The dictionary contating parameter values of the gaussian flare. + + The parameters for the gaussian flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the gaussian. Returns ------- The y values for the gaussian flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(params["sig"])[:, jnp.newaxis] return jnp.sum(A * jnp.exp(-((t - t0) ** 2) / (2 * (sig**2))), axis=0) -def _exponential(t, mean_params): +def _exponential(t, params): """An exponential flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the exponential. + + params: dict + The dictionary contating parameter values of the exponential flare. + + The parameters for the exponential flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the exponential. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(params["sig"])[:, jnp.newaxis] return jnp.sum(A * jnp.exp(-jnp.abs(t - t0) / (2 * (sig**2))), axis=0) -def _constant(t, mean_params): +def _constant(t, params): """A constant mean shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Constant amplitude of the flare. + + params: dict + The dictionary contating parameter values of the constant flare. + + The parameters for the constant flare are: + A: jnp.float + Constant amplitude of the flare. Returns ------- The constant value. """ - return mean_params["A"] * jnp.ones_like(t) + return params["A"] * jnp.ones_like(t) -def _skew_gaussian(t, mean_params): +def _skew_gaussian(t, params): """A skew gaussian flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the rising edge. - sig2: - The width parameter for the falling edge. + + params: dict + The dictionary contating parameter values of the skew gaussian flare. + + The parameters for the skew gaussian flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the rising edge. + sig2: jnp.float / jnp.ndarray + The width parameter for the falling edge. Returns ------- The y values for skew gaussian flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] - sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(params["sig2"])[:, jnp.newaxis] y = jnp.sum( A @@ -231,30 +266,35 @@ def _skew_gaussian(t, mean_params): return y -def _skew_exponential(t, mean_params): +def _skew_exponential(t, params): """A skew exponential flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the rising edge. - sig2: - The width parameter for the falling edge. + + params: dict + The dictionary contating parameter values of the skew exponential flare. + + The parameters for the skew exponential flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the rising edge. + sig2: jnp.float / jnp.ndarray + The width parameter for the falling edge. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] - sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(params["sig2"])[:, jnp.newaxis] y = jnp.sum( A @@ -268,30 +308,35 @@ def _skew_exponential(t, mean_params): return y -def _fred(t, mean_params): +def _fred(t, params): """A fast rise exponential decay (FRED) flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - phi: - Symmetry parameter of the flare. - delta: - Offset parameter of the flare. + + params: dict + The dictionary contating parameter values of the FRED flare. + + The parameters for the FRED flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + phi: jnp.float / jnp.ndarray + Symmetry parameter of the flare. + delta: jnp.float / jnp.ndarray + Offset parameter of the flare. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - phi = jnp.atleast_1d(mean_params["phi"])[:, jnp.newaxis] - delta = jnp.atleast_1d(mean_params["delta"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + phi = jnp.atleast_1d(params["phi"])[:, jnp.newaxis] + delta = jnp.atleast_1d(params["delta"])[:, jnp.newaxis] return jnp.sum( A * jnp.exp(-phi * ((t + delta) / t0 + t0 / (t + delta))) * jnp.exp(2 * phi), axis=0 From 2d0ba466ec7b29130f268bffd30ecf063a7d82e0 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Wed, 30 Aug 2023 16:27:00 +0530 Subject: [PATCH 45/50] Docstrings Updated --- stingray/modeling/gpmodeling.py | 20 ++++++++++++-------- stingray/modeling/tests/test_gpmodeling.py | 4 +++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 8780e6af6..2f7e799b0 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -113,13 +113,13 @@ def get_mean(mean_type, mean_params): Examples -------- Unimodal Gaussian Mean Function: - mean_params = {"A": 3.0, "t0": 0.2, "sig1": 0.1, "sig2": 0.4} + mean_params = {"A": 3.0, "t0": 0.2, "sig": 0.1} mean = get_mean("gaussian", mean_params) - Multimodal Gaussian Mean Function: + Multimodal Skew Gaussian Mean Function: mean_params = {"A": jnp.array([3.0, 4.0]), "t0": jnp.array([0.2, 1]), "sig1": jnp.array([0.1, 0.4]), "sig2": jnp.array([0.4, 0.1])} - mean = get_mean("gaussian", mean_params) + mean = get_mean("skew_gaussian", mean_params) """ if not jax_avail: @@ -338,10 +338,12 @@ def _fred(t, params): phi = jnp.atleast_1d(params["phi"])[:, jnp.newaxis] delta = jnp.atleast_1d(params["delta"])[:, jnp.newaxis] - return jnp.sum( + y = jnp.sum( A * jnp.exp(-phi * ((t + delta) / t0 + t0 / (t + delta))) * jnp.exp(2 * phi), axis=0 ) + return y + def _get_kernel_params(kernel_type): """ @@ -485,8 +487,10 @@ def prior_model(): for i in params_list: if isinstance(prior_dict[i], tfpd.Distribution): parameter = yield Prior(prior_dict[i], name=i) - else: + elif isinstance(prior_dict[i], Prior): parameter = yield prior_dict[i] + else: + raise ValueError("Prior should be a tfpd distribution or a jaxns prior.") prior_list.append(parameter) return tuple(prior_list) @@ -607,10 +611,10 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): self.prior_model = prior_model self.log_likelihood_model = likelihood_model - NSmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) - NSmodel.sanity_check(random.PRNGKey(10), S=100) + nsmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) + nsmodel.sanity_check(random.PRNGKey(10), S=100) - self.exact_ns = ExactNestedSampler(NSmodel, num_live_points=500, max_samples=max_samples) + self.exact_ns = ExactNestedSampler(nsmodel, num_live_points=500, max_samples=max_samples) termination_reason, State = self.exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index 722467ce3..dca1e48ca 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -262,7 +262,9 @@ def setup_class(self): # The prior dictionary, with suitable tfpd prior distributions prior_dict = { - "log_A": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), + "log_A": Prior( + tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), name="log_A" + ), "t0": tfpd.Uniform(low=self.Times[0] - 0.1 * T, high=self.Times[-1] + 0.1 * T), "log_sig": tfpd.Uniform(low=jnp.log(0.5 * 1 / f), high=jnp.log(2 * T)), "log_arn": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), From c3778a3ddb9f55a523dde76d04911eae528b7d2c Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 5 Sep 2023 19:28:55 +0530 Subject: [PATCH 46/50] Small doc change --- stingray/modeling/gpmodeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 2f7e799b0..5c953604a 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -615,10 +615,10 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): nsmodel.sanity_check(random.PRNGKey(10), S=100) self.exact_ns = ExactNestedSampler(nsmodel, num_live_points=500, max_samples=max_samples) - termination_reason, State = self.exact_ns( + termination_reason, state = self.exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) - self.results = self.exact_ns.to_results(State, termination_reason) + self.results = self.exact_ns.to_results(state, termination_reason) print("Simulation Complete") def get_evidence(self): From d3447f41627754b0ee118a18be4d2342fb177922 Mon Sep 17 00:00:00 2001 From: Daniela Huppenkothen Date: Fri, 29 Sep 2023 11:17:44 +0200 Subject: [PATCH 47/50] Fixed usage of dict, added functions to __all__ --- stingray/modeling/gpmodeling.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 5c953604a..55a20e067 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -37,7 +37,8 @@ tfp_available = False -__all__ = ["GPResult"] +__all__ = ["get_kernel", "get_mean", "get_prior", + "get_log_likelihood", "GPResult", "get_gp_params"] def get_kernel(kernel_type, kernel_params): @@ -542,14 +543,14 @@ def get_log_likelihood(params_list, kernel_type, mean_type, times, counts, **kwa @jit def likelihood_model(*args): - dict = {} + param_dict = {} for i, params in enumerate(params_list): if params[0:4] == "log_": - dict[params[4:]] = jnp.exp(args[i]) + param_dict[params[4:]] = jnp.exp(args[i]) else: - dict[params] = args[i] - kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict) - mean = get_mean(mean_type=mean_type, mean_params=dict) + param_dict[params] = args[i] + kernel = get_kernel(kernel_type=kernel_type, kernel_params=param_dict) + mean = get_mean(mean_type=mean_type, mean_params=param_dict) gp = GaussianProcess(kernel, times, mean_value=mean(times)) return gp.log_probability(counts) From 1fc9f84665e36e770b738dea2d8dd36dacdc38f7 Mon Sep 17 00:00:00 2001 From: Daniela Huppenkothen Date: Fri, 29 Sep 2023 11:23:07 +0200 Subject: [PATCH 48/50] Added num_live_points as parameter --- stingray/modeling/gpmodeling.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index 55a20e067..85c15ac72 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -575,7 +575,8 @@ def __init__(self, lc: Lightcurve) -> None: self.counts = lc.counts self.result = None - def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): + def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4, + num_live_points=500): """ Makes a Jaxns nested sampler over the Gaussian Process, given the prior and likelihood model @@ -597,6 +598,9 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): max_samples: int, default 1e4 The maximum number of samples to be taken by the nested sampler + num_live_points : int, default 500 + The number of live points to use in the nested sampling + Returns ---------- results: jaxns.results.NestedSamplerResults object @@ -615,7 +619,7 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4): nsmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model) nsmodel.sanity_check(random.PRNGKey(10), S=100) - self.exact_ns = ExactNestedSampler(nsmodel, num_live_points=500, max_samples=max_samples) + self.exact_ns = ExactNestedSampler(nsmodel, num_live_points=num_live_points, max_samples=max_samples) termination_reason, state = self.exact_ns( random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4) ) From 78e13f4332f1934769865e9dc3c4113838493106 Mon Sep 17 00:00:00 2001 From: Daniela Huppenkothen Date: Fri, 29 Sep 2023 11:24:08 +0200 Subject: [PATCH 49/50] Fixed typo in changelog --- docs/changes/739.feature.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/changes/739.feature.rst b/docs/changes/739.feature.rst index ff7cecff8..e42f8862a 100644 --- a/docs/changes/739.feature.rst +++ b/docs/changes/739.feature.rst @@ -2,7 +2,7 @@ This is a JAX implementation of the GP tool by `Hubener et al Date: Fri, 29 Sep 2023 11:28:19 +0200 Subject: [PATCH 50/50] Added installation instructions for GP Modeling --- docs/index.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index b250d1297..d2db9152b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -113,6 +113,20 @@ To install all required and recommended dependencies in a recent installation, y $ pip install astropy scipy matplotlib numpy h5py tqdm numba pint-pulsar emcee corner statsmodels pyfftw tbb +For the Gaussian Process modeling in `stingray.modeling.gpmodeling`, you'll need the following extra packages + ++ jax ++ jaxns ++ tensorflow ++ tensorflow-probability ++ tinygp ++ etils ++ typing_extensions + +Most of these are installed via ``pip``, but if you have an Nvidia GPU available, you'll want to take special care +following the installation instructions for jax and tensorflow(-probability) in order to enable GPU support and +take advantage of those speed-ups. + For development work, you will need the following extra libraries: + pytest