From 92c52c8e19eab897fdae1245c0fc28956e318cc4 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Tue, 29 Aug 2023 01:26:57 +0530 Subject: [PATCH] Improved get_mean docs --- stingray/modeling/gpmodeling.py | 185 ++++++++++++++++++++------------ 1 file changed, 115 insertions(+), 70 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index e52a71777..8780e6af6 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -92,7 +92,7 @@ def get_kernel(kernel_type, kernel_params): def get_mean(mean_type, mean_params): """ - Function for producing the mean for the Gaussian Process. + Function for producing the mean function for the Gaussian Process. Parameters ---------- @@ -106,118 +106,153 @@ def get_mean(mean_type, mean_params): Dictionary containing the parameters for the mean Should contain the parameters for the selected mean + Returns + ------- + A function which takes in the time coordinates and returns the mean values. + + Examples + -------- + Unimodal Gaussian Mean Function: + mean_params = {"A": 3.0, "t0": 0.2, "sig1": 0.1, "sig2": 0.4} + mean = get_mean("gaussian", mean_params) + + Multimodal Gaussian Mean Function: + mean_params = {"A": jnp.array([3.0, 4.0]), "t0": jnp.array([0.2, 1]), + "sig1": jnp.array([0.1, 0.4]), "sig2": jnp.array([0.4, 0.1])} + mean = get_mean("gaussian", mean_params) + """ if not jax_avail: raise ImportError("Jax is required") if mean_type == "gaussian": - mean = functools.partial(_gaussian, mean_params=mean_params) + mean = functools.partial(_gaussian, params=mean_params) elif mean_type == "exponential": - mean = functools.partial(_exponential, mean_params=mean_params) + mean = functools.partial(_exponential, params=mean_params) elif mean_type == "constant": - mean = functools.partial(_constant, mean_params=mean_params) + mean = functools.partial(_constant, params=mean_params) elif mean_type == "skew_gaussian": - mean = functools.partial(_skew_gaussian, mean_params=mean_params) + mean = functools.partial(_skew_gaussian, params=mean_params) elif mean_type == "skew_exponential": - mean = functools.partial(_skew_exponential, mean_params=mean_params) + mean = functools.partial(_skew_exponential, params=mean_params) elif mean_type == "fred": - mean = functools.partial(_fred, mean_params=mean_params) + mean = functools.partial(_fred, params=mean_params) else: raise ValueError("Mean type not implemented") return mean -def _gaussian(t, mean_params): +def _gaussian(t, params): """A gaussian flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the gaussian. + + params: dict + The dictionary contating parameter values of the gaussian flare. + + The parameters for the gaussian flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the gaussian. Returns ------- The y values for the gaussian flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(params["sig"])[:, jnp.newaxis] return jnp.sum(A * jnp.exp(-((t - t0) ** 2) / (2 * (sig**2))), axis=0) -def _exponential(t, mean_params): +def _exponential(t, params): """An exponential flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the exponential. + + params: dict + The dictionary contating parameter values of the exponential flare. + + The parameters for the exponential flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the exponential. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig = jnp.atleast_1d(params["sig"])[:, jnp.newaxis] return jnp.sum(A * jnp.exp(-jnp.abs(t - t0) / (2 * (sig**2))), axis=0) -def _constant(t, mean_params): +def _constant(t, params): """A constant mean shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Constant amplitude of the flare. + + params: dict + The dictionary contating parameter values of the constant flare. + + The parameters for the constant flare are: + A: jnp.float + Constant amplitude of the flare. Returns ------- The constant value. """ - return mean_params["A"] * jnp.ones_like(t) + return params["A"] * jnp.ones_like(t) -def _skew_gaussian(t, mean_params): +def _skew_gaussian(t, params): """A skew gaussian flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the rising edge. - sig2: - The width parameter for the falling edge. + + params: dict + The dictionary contating parameter values of the skew gaussian flare. + + The parameters for the skew gaussian flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the rising edge. + sig2: jnp.float / jnp.ndarray + The width parameter for the falling edge. Returns ------- The y values for skew gaussian flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] - sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(params["sig2"])[:, jnp.newaxis] y = jnp.sum( A @@ -231,30 +266,35 @@ def _skew_gaussian(t, mean_params): return y -def _skew_exponential(t, mean_params): +def _skew_exponential(t, params): """A skew exponential flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - sig1: - The width parameter for the rising edge. - sig2: - The width parameter for the falling edge. + + params: dict + The dictionary contating parameter values of the skew exponential flare. + + The parameters for the skew exponential flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + sig1: jnp.float / jnp.ndarray + The width parameter for the rising edge. + sig2: jnp.float / jnp.ndarray + The width parameter for the falling edge. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis] - sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + sig1 = jnp.atleast_1d(params["sig1"])[:, jnp.newaxis] + sig2 = jnp.atleast_1d(params["sig2"])[:, jnp.newaxis] y = jnp.sum( A @@ -268,30 +308,35 @@ def _skew_exponential(t, mean_params): return y -def _fred(t, mean_params): +def _fred(t, params): """A fast rise exponential decay (FRED) flare shape. Parameters ---------- t: jnp.ndarray The time coordinates. - A: jnp.int - Amplitude of the flare. - t0: - The location of the maximum. - phi: - Symmetry parameter of the flare. - delta: - Offset parameter of the flare. + + params: dict + The dictionary contating parameter values of the FRED flare. + + The parameters for the FRED flare are: + A: jnp.float / jnp.ndarray + Amplitude of the flare. + t0: jnp.float / jnp.ndarray + The location of the maximum. + phi: jnp.float / jnp.ndarray + Symmetry parameter of the flare. + delta: jnp.float / jnp.ndarray + Offset parameter of the flare. Returns ------- The y values for exponential flare. """ - A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis] - t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis] - phi = jnp.atleast_1d(mean_params["phi"])[:, jnp.newaxis] - delta = jnp.atleast_1d(mean_params["delta"])[:, jnp.newaxis] + A = jnp.atleast_1d(params["A"])[:, jnp.newaxis] + t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis] + phi = jnp.atleast_1d(params["phi"])[:, jnp.newaxis] + delta = jnp.atleast_1d(params["delta"])[:, jnp.newaxis] return jnp.sum( A * jnp.exp(-phi * ((t + delta) / t0 + t0 / (t + delta))) * jnp.exp(2 * phi), axis=0