From d4519ebcdc9b3dab0eadfa4d492e057cbfb9c4b0 Mon Sep 17 00:00:00 2001 From: Gaurav17Joshi Date: Sun, 20 Aug 2023 17:54:38 +0530 Subject: [PATCH] Added log parameters --- stingray/modeling/gpmodeling.py | 39 ++++-- stingray/modeling/tests/test_gpmodeling.py | 144 +++++++++------------ 2 files changed, 90 insertions(+), 93 deletions(-) diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py index f85f13f3f..134ddf145 100644 --- a/stingray/modeling/gpmodeling.py +++ b/stingray/modeling/gpmodeling.py @@ -301,15 +301,20 @@ def _get_kernel_params(kernel_type): ---------- kernel_type: string The type of kernel to be used for the Gaussian Process model + The parameters in log scale have a prefix of "log_" Returns ------- A list of the parameters for the kernel for the GP model """ if kernel_type == "RN": - return ["arn", "crn"] + return ["log_arn", "log_crn"] elif kernel_type == "QPO_plus_RN": - return ["arn", "crn", "aqpo", "cqpo", "freq"] + return ["log_arn", "log_crn", "log_aqpo", "log_cqpo", "log_freq"] + elif kernel_type == "QPO": + return ["log_aqpo", "log_cqpo", "log_freq"] + else: + raise ValueError("Kernel type not implemented") def _get_mean_params(mean_type): @@ -320,19 +325,22 @@ def _get_mean_params(mean_type): ---------- mean_type: string The type of mean to be used for the Gaussian Process model + The parameters in log scale have a prefix of "log_" Returns ------- A list of the parameters for the mean for the GP model """ if (mean_type == "gaussian") or (mean_type == "exponential"): - return ["A", "t0", "sig"] + return ["log_A", "t0", "log_sig"] elif mean_type == "constant": - return ["A"] + return ["log_A"] elif (mean_type == "skew_gaussian") or (mean_type == "skew_exponential"): - return ["A", "t0", "sig1", "sig2"] + return ["log_A", "t0", "log_sig1", "log_sig2"] elif mean_type == "fred": - return ["A", "t0", "delta", "phi"] + return ["log_A", "t0", "delta", "phi"] + else: + raise ValueError("Mean type not implemented") def get_gp_params(kernel_type, mean_type): @@ -355,7 +363,7 @@ def get_gp_params(kernel_type, mean_type): Examples -------- get_gp_params("QPO_plus_RN", "gaussian") - ['arn', 'crn', 'aqpo', 'cqpo', 'freq', 'A', 't0', 'sig'] + ['log_arn', 'log_crn', 'log_aqpo', 'log_cqpo', 'log_freq', 'log_A', 't0', 'log_sig'] """ kernel_params = _get_kernel_params(kernel_type) mean_params = _get_mean_params(mean_type) @@ -381,6 +389,7 @@ def get_prior(params_list, prior_dict): or special priors from jaxns. **Note**: If jaxns priors are used, then the name given to them should be the same as the corresponding name in the params_list. + Also, if a parameter is to be used in the log scale, it should have a prefix of "log_" Returns ------- @@ -403,11 +412,11 @@ def get_prior(params_list, prior_dict): Make a prior dictionary using tensorflow_probability distributions prior_dict = { - "A": tfpd.Uniform(low = 1e-1, high = 2e+2), + "log_A": tfpd.Uniform(low = jnp.log(1e-1), high = jnp.log(2e+2)), "t0": tfpd.Uniform(low = 0.0 - 0.1, high = 1 + 0.1), - "sig": tfpd.Uniform(low = 0.5 * 1 / 20, high = 2 ), - "arn": tfpd.Uniform(low = 0.1 , high = 2 ), - "crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), + "log_sig": tfpd.Uniform(low = jnp.log(0.5 * 1 / 20), high = jnp.log(2) ), + "log_arn": tfpd.Uniform(low = jnp.log(0.1) , high = jnp.log(2) ), + "log_crn": tfpd.Uniform(low = jnp.log(1 /5), high = jnp.log(20)), } prior_model = get_prior(params_list, prior_dict) @@ -441,7 +450,8 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): Makes a jaxns specific log likelihood function which takes in the parameters in the order of the parameters list, and calculates the log likelihood of the data given the parameters, and the model - (kernel, mean) of the GP model. + (kernel, mean) of the GP model. **Note** Any parameters with a prefix + of "log_" are taken to be in the log scale. Parameters ---------- @@ -481,7 +491,10 @@ def get_likelihood(params_list, kernel_type, mean_type, **kwargs): def likelihood_model(*args): dict = {} for i, params in enumerate(params_list): - dict[params] = args[i] + if params[0:4] == "log_": + dict[params[4:]] = jnp.exp(args[i]) + else: + dict[params] = args[i] kernel = get_kernel(kernel_type=kernel_type, kernel_params=dict) mean = get_mean(mean_type=mean_type, mean_params=dict) gp = GaussianProcess(kernel, kwargs["Times"], mean_value=mean(kwargs["Times"])) diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py index e65dff3f8..1d7f5eecd 100644 --- a/stingray/modeling/tests/test_gpmodeling.py +++ b/stingray/modeling/tests/test_gpmodeling.py @@ -173,81 +173,65 @@ def setup_class(self): pass def test_get_gp_params_rn(self): - assert get_gp_params("RN", "gaussian") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "constant") == ["arn", "crn", "A"] - assert get_gp_params("RN", "skew_gaussian") == ["arn", "crn", "A", "t0", "sig1", "sig2"] - assert get_gp_params("RN", "skew_exponential") == [ - "arn", - "crn", - "A", + assert get_gp_params("RN", "gaussian") == ["log_arn", "log_crn", "log_A", "t0", "log_sig"] + assert get_gp_params("RN", "constant") == ["log_arn", "log_crn", "log_A"] + assert get_gp_params("RN", "skew_gaussian") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "log_sig1", + "log_sig2", ] - assert get_gp_params("RN", "exponential") == ["arn", "crn", "A", "t0", "sig"] - assert get_gp_params("RN", "fred") == ["arn", "crn", "A", "t0", "delta", "phi"] - - def test_get_gp_params_qpo_plus_rn(self): - assert get_gp_params("QPO_plus_RN", "gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "skew_exponential") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig", - ] - assert get_gp_params("QPO_plus_RN", "constant") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + "log_sig1", + "log_sig2", ] - assert get_gp_params("QPO_plus_RN", "skew_gaussian") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "exponential") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "log_sig", ] - assert get_gp_params("QPO_plus_RN", "skew_exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + assert get_gp_params("RN", "fred") == [ + "log_arn", + "log_crn", + "log_A", "t0", - "sig1", - "sig2", + "delta", + "phi", ] - assert get_gp_params("QPO_plus_RN", "exponential") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + + def test_get_gp_params_qpo_plus_rn(self): + assert get_gp_params("QPO_plus_RN", "gaussian") == [ + "log_arn", + "log_crn", + "log_aqpo", + "log_cqpo", + "log_freq", + "log_A", "t0", - "sig", + "log_sig", ] - assert get_gp_params("QPO_plus_RN", "fred") == [ - "arn", - "crn", - "aqpo", - "cqpo", - "freq", - "A", + with pytest.raises(ValueError, match="Mean type not implemented"): + get_gp_params("QPO_plus_RN", "notimplemented") + + with pytest.raises(ValueError, match="Kernel type not implemented"): + get_gp_params("notimplemented", "gaussian") + + def test_get_qpo(self): + assert get_gp_params("QPO", "gaussian") == [ + "log_aqpo", + "log_cqpo", + "log_freq", + "log_A", "t0", - "delta", - "phi", + "log_sig", ] @@ -278,11 +262,11 @@ def setup_class(self): # The prior dictionary, with suitable tfpd prior distributions prior_dict = { - "A": tfpd.Uniform(low=0.1 * span, high=2 * span), + "log_A": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), "t0": tfpd.Uniform(low=self.Times[0] - 0.1 * T, high=self.Times[-1] + 0.1 * T), - "sig": tfpd.Uniform(low=0.5 * 1 / f, high=2 * T), - "arn": tfpd.Uniform(low=0.1 * span, high=2 * span), - "crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)), + "log_sig": tfpd.Uniform(low=jnp.log(0.5 * 1 / f), high=jnp.log(2 * T)), + "log_arn": tfpd.Uniform(low=jnp.log(0.1 * span), high=jnp.log(2 * span)), + "log_crn": tfpd.Uniform(low=jnp.log(1 / T), high=jnp.log(f)), } prior_model = get_prior(self.params_list, prior_dict) @@ -339,15 +323,15 @@ def test_max_likelihood_parameters(self): assert key in self.gpresult.get_max_likelihood_parameters() def test_posterior_plot(self): - self.gpresult.posterior_plot("A") + self.gpresult.posterior_plot("log_A") assert plt.fignum_exists(1) def test_posterior_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_Posterior_plot.png" + outfname = "log_A_Posterior_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.posterior_plot("A", save=True) + self.gpresult.posterior_plot("log_A", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -356,20 +340,20 @@ def test_posterior_plot_labels_and_fname(self): outfname = "blabla.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.posterior_plot("A", axis=[0, 14, 0, 0.5], save=True, filename=outfname) + self.gpresult.posterior_plot("log_A", axis=[0, 14, 0, 0.5], save=True, filename=outfname) assert os.path.exists(outfname) os.unlink(outfname) def test_weighted_posterior_plot(self): - self.gpresult.weighted_posterior_plot("A") + self.gpresult.weighted_posterior_plot("log_A") assert plt.fignum_exists(1) def test_weighted_posterior_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_Weighted_Posterior_plot.png" + outfname = "log_A_Weighted_Posterior_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.weighted_posterior_plot("A", save=True) + self.gpresult.weighted_posterior_plot("log_A", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -379,21 +363,21 @@ def test_weighted_posterior_plot_labels_and_fname(self): if os.path.exists(outfname): os.unlink(outfname) self.gpresult.weighted_posterior_plot( - "A", axis=[0, 14, 0, 0.5], save=True, filename=outfname + "log_A", axis=[0, 14, 0, 0.5], save=True, filename=outfname ) assert os.path.exists(outfname) os.unlink(outfname) def test_corner_plot(self): - self.gpresult.corner_plot("A", "t0") + self.gpresult.corner_plot("log_A", "t0") assert plt.fignum_exists(1) def test_corner_plot_labels_and_fname_default(self): clear_all_figs() - outfname = "A_t0_Corner_plot.png" + outfname = "log_A_t0_Corner_plot.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("A", "t0", save=True) + self.gpresult.corner_plot("log_A", "t0", save=True) assert os.path.exists(outfname) os.unlink(outfname) @@ -402,6 +386,6 @@ def test_corner_plot_labels_and_fname(self): outfname = "blabla.png" if os.path.exists(outfname): os.unlink(outfname) - self.gpresult.corner_plot("A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) + self.gpresult.corner_plot("log_A", "t0", axis=[0, 0.5, 0, 5], save=True, filename=outfname) assert os.path.exists(outfname) os.unlink(outfname)