Skip to content

Commit

Permalink
Added log parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav17Joshi committed Aug 24, 2023
1 parent b4622ad commit de35cb5
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 93 deletions.
39 changes: 26 additions & 13 deletions stingray/modeling/gpmodeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,20 @@ def _get_kernel_params(kernel_type):
----------
kernel_type: string
The type of kernel to be used for the Gaussian Process model
The parameters in log scale have a prefix of "log_"
Returns
-------
A list of the parameters for the kernel for the GP model
"""
if kernel_type == "RN":
return ["arn", "crn"]
return ["log_arn", "log_crn"]
elif kernel_type == "QPO_plus_RN":
return ["arn", "crn", "aqpo", "cqpo", "freq"]
return ["log_arn", "log_crn", "log_aqpo", "log_cqpo", "log_freq"]
elif kernel_type == "QPO":
return ["log_aqpo", "log_cqpo", "log_freq"]
else:
raise ValueError("Kernel type not implemented")


def _get_mean_params(mean_type):
Expand All @@ -320,19 +325,22 @@ def _get_mean_params(mean_type):
----------
mean_type: string
The type of mean to be used for the Gaussian Process model
The parameters in log scale have a prefix of "log_"
Returns
-------
A list of the parameters for the mean for the GP model
"""
if (mean_type == "gaussian") or (mean_type == "exponential"):
return ["A", "t0", "sig"]
return ["log_A", "t0", "log_sig"]
elif mean_type == "constant":
return ["A"]
return ["log_A"]
elif (mean_type == "skew_gaussian") or (mean_type == "skew_exponential"):
return ["A", "t0", "sig1", "sig2"]
return ["log_A", "t0", "log_sig1", "log_sig2"]
elif mean_type == "fred":
return ["A", "t0", "delta", "phi"]
return ["log_A", "t0", "delta", "phi"]
else:
raise ValueError("Mean type not implemented")


def get_gp_params(kernel_type, mean_type):
Expand All @@ -355,7 +363,7 @@ def get_gp_params(kernel_type, mean_type):
Examples
--------
get_gp_params("QPO_plus_RN", "gaussian")
['arn', 'crn', 'aqpo', 'cqpo', 'freq', 'A', 't0', 'sig']
['log_arn', 'log_crn', 'log_aqpo', 'log_cqpo', 'log_freq', 'log_A', 't0', 'log_sig']
"""
kernel_params = _get_kernel_params(kernel_type)
mean_params = _get_mean_params(mean_type)
Expand All @@ -381,6 +389,7 @@ def get_prior(params_list, prior_dict):
or special priors from jaxns.
**Note**: If jaxns priors are used, then the name given to them should be the same as
the corresponding name in the params_list.
Also, if a parameter is to be used in the log scale, it should have a prefix of "log_"
Returns
-------
Expand All @@ -403,11 +412,11 @@ def get_prior(params_list, prior_dict):
Make a prior dictionary using tensorflow_probability distributions
prior_dict = {
"A": tfpd.Uniform(low = 1e-1, high = 2e+2),
"log_A": tfpd.Uniform(low = jnp.log(1e-1), high = jnp.log(2e+2)),
"t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1),
"sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ),
"arn": tfpd.Uniform(low = 0.1 , high = 2 ),
"crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)),
"log_sig": tfpd.Uniform(low = jnp.log(0.5 * 1 / 20), high = jnp.log(2) ),
"log_arn": tfpd.Uniform(low = jnp.log(0.1) , high = jnp.log(2) ),
"log_crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)),
}
prior_model = get_prior(params_list, prior_dict)
Expand Down Expand Up @@ -441,7 +450,8 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs):
Makes a jaxns specific log likelihood function which takes in the
parameters in the order of the parameters list, and calculates the
log likelihood of the data given the parameters, and the model
(kernel, mean) of the GP model.
(kernel, mean) of the GP model. **Note** Any parameters with a prefix
of "log_" are taken to be in the log scale.
Parameters
----------
Expand Down Expand Up @@ -481,7 +491,10 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs):
def likelihood_model(*args):
dict = {}
for i, params in enumerate(params_list):
dict[params] = args[i]
if params[0:4] == "log_":
dict[params[4:]] = jnp.exp(args[i])
else:
dict[params] = args[i]
kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict)
mean = get_mean(mean_type=mean_type, mean_params=dict)
gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"]))
Expand Down
144 changes: 64 additions & 80 deletions stingray/modeling/tests/test_gpmodeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,81 +173,65 @@ def setup_class(self):
pass

def test_get_gp_params_rn(self):
assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"]
assert get_gp_params("RN", "constant") == ["arn", "crn", "A"]
assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"]
assert get_gp_params("RN", "skew_exponential") == [
"arn",
"crn",
"A",
assert get_gp_params("RN", "gaussian") == ["log_arn", "log_crn", "log_A", "t0", "log_sig"]
assert get_gp_params("RN", "constant") == ["log_arn", "log_crn", "log_A"]
assert get_gp_params("RN", "skew_gaussian") == [
"log_arn",
"log_crn",
"log_A",
"t0",
"sig1",
"sig2",
"log_sig1",
"log_sig2",
]
assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"]
assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"]

def test_get_gp_params_qpo_plus_rn(self):
assert get_gp_params("QPO_plus_RN", "gaussian") == [
"arn",
"crn",
"aqpo",
"cqpo",
"freq",
"A",
assert get_gp_params("RN", "skew_exponential") == [
"log_arn",
"log_crn",
"log_A",
"t0",
"sig",
]
assert get_gp_params("QPO_plus_RN", "constant") == [
"arn",
"crn",
"aqpo",
"cqpo",
"freq",
"A",
"log_sig1",
"log_sig2",
]
assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [
"arn",
"crn",
"aqpo",
"cqpo",
"freq",
"A",
assert get_gp_params("RN", "exponential") == [
"log_arn",
"log_crn",
"log_A",
"t0",
"sig1",
"sig2",
"log_sig",
]
assert get_gp_params("QPO_plus_RN", "skew_exponential") == [
"arn",
"crn",
"aqpo",
"cqpo",
"freq",
"A",
assert get_gp_params("RN", "fred") == [
"log_arn",
"log_crn",
"log_A",
"t0",
"sig1",
"sig2",
"delta",
"phi",
]
assert get_gp_params("QPO_plus_RN", "exponential") == [
"arn",
"crn",
"aqpo",
"cqpo",
"freq",
"A",

def test_get_gp_params_qpo_plus_rn(self):
assert get_gp_params("QPO_plus_RN", "gaussian") == [
"log_arn",
"log_crn",
"log_aqpo",
"log_cqpo",
"log_freq",
"log_A",
"t0",
"sig",
"log_sig",
]
assert get_gp_params("QPO_plus_RN", "fred") == [
"arn",
"crn",
"aqpo",
"cqpo",
"freq",
"A",
with pytest.raises(ValueError, match="Mean type not implemented"):
get_gp_params("QPO_plus_RN", "notimplemented")

with pytest.raises(ValueError, match="Kernel type not implemented"):
get_gp_params("notimplemented", "gaussian")

def test_get_qpo(self):
assert get_gp_params("QPO", "gaussian") == [
"log_aqpo",
"log_cqpo",
"log_freq",
"log_A",
"t0",
"delta",
"phi",
"log_sig",
]


Expand Down Expand Up @@ -278,11 +262,11 @@ def setup_class(self):

# The prior dictionary, with suitable tfpd prior distributions
prior_dict = {
"A": tfpd.Uniform(low=0.1 * span, high=2 * span),
"log_A": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)),
"t0": tfpd.Uniform(low=self.Times[0] - 0.1 * T, high=self.Times[-1] + 0.1 * T),
"sig": tfpd.Uniform(low=0.5 * 1 / f, high=2 * T),
"arn": tfpd.Uniform(low=0.1 * span, high=2 * span),
"crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)),
"log_sig": tfpd.Uniform(low=jnp.log(0.5 * 1 / f), high=jnp.log(2 * T)),
"log_arn": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)),
"log_crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)),
}

prior_model = get_prior(self.params_list, prior_dict)
Expand Down Expand Up @@ -339,15 +323,15 @@ def test_max_likelihood_parameters(self):
assert key in self.gpresult.get_max_likelihood_parameters()

def test_posterior_plot(self):
self.gpresult.posterior_plot("A")
self.gpresult.posterior_plot("log_A")
assert plt.fignum_exists(1)

def test_posterior_plot_labels_and_fname_default(self):
clear_all_figs()
outfname = "A_Posterior_plot.png"
outfname = "log_A_Posterior_plot.png"
if os.path.exists(outfname):
os.unlink(outfname)
self.gpresult.posterior_plot("A", save=True)
self.gpresult.posterior_plot("log_A", save=True)
assert os.path.exists(outfname)
os.unlink(outfname)

Expand All @@ -356,20 +340,20 @@ def test_posterior_plot_labels_and_fname(self):
outfname = "blabla.png"
if os.path.exists(outfname):
os.unlink(outfname)
self.gpresult.posterior_plot("A", axis=[0, 14, 0, 0.5], save=True, filename=outfname)
self.gpresult.posterior_plot("log_A", axis=[0, 14, 0, 0.5], save=True, filename=outfname)
assert os.path.exists(outfname)
os.unlink(outfname)

def test_weighted_posterior_plot(self):
self.gpresult.weighted_posterior_plot("A")
self.gpresult.weighted_posterior_plot("log_A")
assert plt.fignum_exists(1)

def test_weighted_posterior_plot_labels_and_fname_default(self):
clear_all_figs()
outfname = "A_Weighted_Posterior_plot.png"
outfname = "log_A_Weighted_Posterior_plot.png"
if os.path.exists(outfname):
os.unlink(outfname)
self.gpresult.weighted_posterior_plot("A", save=True)
self.gpresult.weighted_posterior_plot("log_A", save=True)
assert os.path.exists(outfname)
os.unlink(outfname)

Expand All @@ -379,21 +363,21 @@ def test_weighted_posterior_plot_labels_and_fname(self):
if os.path.exists(outfname):
os.unlink(outfname)
self.gpresult.weighted_posterior_plot(
"A", axis=[0, 14, 0, 0.5], save=True, filename=outfname
"log_A", axis=[0, 14, 0, 0.5], save=True, filename=outfname
)
assert os.path.exists(outfname)
os.unlink(outfname)

def test_corner_plot(self):
self.gpresult.corner_plot("A", "t0")
self.gpresult.corner_plot("log_A", "t0")
assert plt.fignum_exists(1)

def test_corner_plot_labels_and_fname_default(self):
clear_all_figs()
outfname = "A_t0_Corner_plot.png"
outfname = "log_A_t0_Corner_plot.png"
if os.path.exists(outfname):
os.unlink(outfname)
self.gpresult.corner_plot("A", "t0", save=True)
self.gpresult.corner_plot("log_A", "t0", save=True)
assert os.path.exists(outfname)
os.unlink(outfname)

Expand All @@ -402,6 +386,6 @@ def test_corner_plot_labels_and_fname(self):
outfname = "blabla.png"
if os.path.exists(outfname):
os.unlink(outfname)
self.gpresult.corner_plot("A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname)
self.gpresult.corner_plot("log_A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname)
assert os.path.exists(outfname)
os.unlink(outfname)

0 comments on commit de35cb5

Please sign in to comment.