Skip to content

Commit

Permalink
Improved get_mean docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav17Joshi committed Aug 30, 2023
1 parent fae3170 commit 8c64116
Showing 1 changed file with 115 additions and 70 deletions.
185 changes: 115 additions & 70 deletions stingray/modeling/gpmodeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_kernel(kernel_type, kernel_params):

def get_mean(mean_type, mean_params):
"""
Function for producing the mean for the Gaussian Process.
Function for producing the mean function for the Gaussian Process.
Parameters
----------
Expand All @@ -106,118 +106,153 @@ def get_mean(mean_type, mean_params):
Dictionary containing the parameters for the mean
Should contain the parameters for the selected mean
Returns
-------
A function which takes in the time coordinates and returns the mean values.
Examples
--------
Unimodal Gaussian Mean Function:
mean_params = {"A": 3.0, "t0": 0.2, "sig1": 0.1, "sig2": 0.4}
mean = get_mean("gaussian", mean_params)
Multimodal Gaussian Mean Function:
mean_params = {"A": jnp.array([3.0, 4.0]), "t0": jnp.array([0.2, 1]),
"sig1": jnp.array([0.1, 0.4]), "sig2": jnp.array([0.4, 0.1])}
mean = get_mean("gaussian", mean_params)
"""
if not jax_avail:
raise ImportError("Jax is required")

if mean_type == "gaussian":
mean = functools.partial(_gaussian, mean_params=mean_params)
mean = functools.partial(_gaussian, params=mean_params)
elif mean_type == "exponential":
mean = functools.partial(_exponential, mean_params=mean_params)
mean = functools.partial(_exponential, params=mean_params)
elif mean_type == "constant":
mean = functools.partial(_constant, mean_params=mean_params)
mean = functools.partial(_constant, params=mean_params)
elif mean_type == "skew_gaussian":
mean = functools.partial(_skew_gaussian, mean_params=mean_params)
mean = functools.partial(_skew_gaussian, params=mean_params)
elif mean_type == "skew_exponential":
mean = functools.partial(_skew_exponential, mean_params=mean_params)
mean = functools.partial(_skew_exponential, params=mean_params)
elif mean_type == "fred":
mean = functools.partial(_fred, mean_params=mean_params)
mean = functools.partial(_fred, params=mean_params)
else:
raise ValueError("Mean type not implemented")
return mean


def _gaussian(t, mean_params):
def _gaussian(t, params):
"""A gaussian flare shape.
Parameters
----------
t: jnp.ndarray
The time coordinates.
A: jnp.int
Amplitude of the flare.
t0:
The location of the maximum.
sig1:
The width parameter for the gaussian.
params: dict
The dictionary contating parameter values of the gaussian flare.
The parameters for the gaussian flare are:
A: jnp.float / jnp.ndarray
Amplitude of the flare.
t0: jnp.float / jnp.ndarray
The location of the maximum.
sig1: jnp.float / jnp.ndarray
The width parameter for the gaussian.
Returns
-------
The y values for the gaussian flare.
"""
A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis]
sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis]
A = jnp.atleast_1d(params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis]
sig = jnp.atleast_1d(params["sig"])[:, jnp.newaxis]

return jnp.sum(A * jnp.exp(-((t - t0) ** 2) / (2 * (sig**2))), axis=0)


def _exponential(t, mean_params):
def _exponential(t, params):
"""An exponential flare shape.
Parameters
----------
t: jnp.ndarray
The time coordinates.
A: jnp.int
Amplitude of the flare.
t0:
The location of the maximum.
sig1:
The width parameter for the exponential.
params: dict
The dictionary contating parameter values of the exponential flare.
The parameters for the exponential flare are:
A: jnp.float / jnp.ndarray
Amplitude of the flare.
t0: jnp.float / jnp.ndarray
The location of the maximum.
sig1: jnp.float / jnp.ndarray
The width parameter for the exponential.
Returns
-------
The y values for exponential flare.
"""
A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis]
sig = jnp.atleast_1d(mean_params["sig"])[:, jnp.newaxis]
A = jnp.atleast_1d(params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis]
sig = jnp.atleast_1d(params["sig"])[:, jnp.newaxis]

return jnp.sum(A * jnp.exp(-jnp.abs(t - t0) / (2 * (sig**2))), axis=0)


def _constant(t, mean_params):
def _constant(t, params):
"""A constant mean shape.
Parameters
----------
t: jnp.ndarray
The time coordinates.
A: jnp.int
Constant amplitude of the flare.
params: dict
The dictionary contating parameter values of the constant flare.
The parameters for the constant flare are:
A: jnp.float
Constant amplitude of the flare.
Returns
-------
The constant value.
"""
return mean_params["A"] * jnp.ones_like(t)
return params["A"] * jnp.ones_like(t)


def _skew_gaussian(t, mean_params):
def _skew_gaussian(t, params):
"""A skew gaussian flare shape.
Parameters
----------
t: jnp.ndarray
The time coordinates.
A: jnp.int
Amplitude of the flare.
t0:
The location of the maximum.
sig1:
The width parameter for the rising edge.
sig2:
The width parameter for the falling edge.
params: dict
The dictionary contating parameter values of the skew gaussian flare.
The parameters for the skew gaussian flare are:
A: jnp.float / jnp.ndarray
Amplitude of the flare.
t0: jnp.float / jnp.ndarray
The location of the maximum.
sig1: jnp.float / jnp.ndarray
The width parameter for the rising edge.
sig2: jnp.float / jnp.ndarray
The width parameter for the falling edge.
Returns
-------
The y values for skew gaussian flare.
"""
A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis]
sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis]
sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis]
A = jnp.atleast_1d(params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis]
sig1 = jnp.atleast_1d(params["sig1"])[:, jnp.newaxis]
sig2 = jnp.atleast_1d(params["sig2"])[:, jnp.newaxis]

y = jnp.sum(
A
Expand All @@ -231,30 +266,35 @@ def _skew_gaussian(t, mean_params):
return y


def _skew_exponential(t, mean_params):
def _skew_exponential(t, params):
"""A skew exponential flare shape.
Parameters
----------
t: jnp.ndarray
The time coordinates.
A: jnp.int
Amplitude of the flare.
t0:
The location of the maximum.
sig1:
The width parameter for the rising edge.
sig2:
The width parameter for the falling edge.
params: dict
The dictionary contating parameter values of the skew exponential flare.
The parameters for the skew exponential flare are:
A: jnp.float / jnp.ndarray
Amplitude of the flare.
t0: jnp.float / jnp.ndarray
The location of the maximum.
sig1: jnp.float / jnp.ndarray
The width parameter for the rising edge.
sig2: jnp.float / jnp.ndarray
The width parameter for the falling edge.
Returns
-------
The y values for exponential flare.
"""
A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis]
sig1 = jnp.atleast_1d(mean_params["sig1"])[:, jnp.newaxis]
sig2 = jnp.atleast_1d(mean_params["sig2"])[:, jnp.newaxis]
A = jnp.atleast_1d(params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis]
sig1 = jnp.atleast_1d(params["sig1"])[:, jnp.newaxis]
sig2 = jnp.atleast_1d(params["sig2"])[:, jnp.newaxis]

y = jnp.sum(
A
Expand All @@ -268,30 +308,35 @@ def _skew_exponential(t, mean_params):
return y


def _fred(t, mean_params):
def _fred(t, params):
"""A fast rise exponential decay (FRED) flare shape.
Parameters
----------
t: jnp.ndarray
The time coordinates.
A: jnp.int
Amplitude of the flare.
t0:
The location of the maximum.
phi:
Symmetry parameter of the flare.
delta:
Offset parameter of the flare.
params: dict
The dictionary contating parameter values of the FRED flare.
The parameters for the FRED flare are:
A: jnp.float / jnp.ndarray
Amplitude of the flare.
t0: jnp.float / jnp.ndarray
The location of the maximum.
phi: jnp.float / jnp.ndarray
Symmetry parameter of the flare.
delta: jnp.float / jnp.ndarray
Offset parameter of the flare.
Returns
-------
The y values for exponential flare.
"""
A = jnp.atleast_1d(mean_params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(mean_params["t0"])[:, jnp.newaxis]
phi = jnp.atleast_1d(mean_params["phi"])[:, jnp.newaxis]
delta = jnp.atleast_1d(mean_params["delta"])[:, jnp.newaxis]
A = jnp.atleast_1d(params["A"])[:, jnp.newaxis]
t0 = jnp.atleast_1d(params["t0"])[:, jnp.newaxis]
phi = jnp.atleast_1d(params["phi"])[:, jnp.newaxis]
delta = jnp.atleast_1d(params["delta"])[:, jnp.newaxis]

return jnp.sum(
A * jnp.exp(-phi * ((t + delta) / t0 + t0 / (t + delta))) * jnp.exp(2 * phi), axis=0
Expand Down

0 comments on commit 8c64116

Please sign in to comment.