diff --git a/src/ensemble/distributions.py b/src/ensemble/distributions.py index 750a556..37fdba6 100644 --- a/src/ensemble/distributions.py +++ b/src/ensemble/distributions.py @@ -42,27 +42,45 @@ def _create_scipy_dist(self) -> None: class InvGamma(Distribution): def _create_scipy_dist(self) -> None: - res = scipy.optimize.minimize( + optim_params = scipy.optimize.minimize( fun=self._shape_scale, # a *good* friend told me that this is a good initial guess and it works so far??? + # alpha = 3 is because alpha > 2 must be true due to variance formula + # beta = mean * (alpha - 1) after isolating beta from formula for mean x0=[3, self.mean * 2], args=(self.mean, self.variance), ) - print("results from minimizer: ", res.x) - shape, scale = np.abs(res.x) + shape, scale = np.abs(optim_params.x) self._scipy_dist = scipy.stats.invgamma(a=shape, scale=scale) def _shape_scale(self, x, samp_mean, samp_var) -> None: alpha = x[0] beta = x[1] - return ((beta / (alpha - 1)) - samp_mean) ** 2 + ( - (beta**2 / ((alpha - 1) ** 2 * (alpha - 2))) - samp_var - ) ** 2 + mean_guess = beta / (alpha - 1) + variance_guess = beta**2 / ((alpha - 1) ** 2 * (alpha - 2)) + return (mean_guess - samp_mean) ** 2 + (variance_guess - samp_var) ** 2 class Fisk(Distribution): def _create_scipy_dist(self): - raise NotImplementedError + optim_params = scipy.optimize.minimize( + fun=self._shape_scale, + x0=[2, self.mean * 2 / np.pi * np.sin(np.pi / 2)], + args=(self.mean, self.variance), + ) + shape, scale = np.abs(optim_params.x) + print("parameters from optimizer: ", shape, scale) + self._scipy_dist = scipy.stats.fisk(c=shape, scale=scale) + + def _shape_scale(self, x, samp_mean, samp_var) -> None: + alpha = x[0] + beta = x[1] + b = np.pi / beta + mean_guess = alpha * b / np.sin(b) + variance_guess = alpha**2 * ( + (2 * b / np.sin(2 * b)) - b**2 / np.sin(b) ** 2 + ) + return (mean_guess - samp_mean) ** 2 + (variance_guess - samp_var) ** 2 class GumbelR(Distribution): diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 9a9911c..a679e4d 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -40,16 +40,19 @@ def test_gamma(): def test_invgamma(): - # raise NotImplementedError invgamma = InvGamma(MEAN, VARIANCE) res = invgamma.stats(moments="mv") - print("mean and var: ", res) assert np.isclose(res[0], MEAN) assert np.isclose(res[1], VARIANCE) def test_fisk(): - raise NotImplementedError + fisk = Fisk(MEAN, VARIANCE) + res = fisk.stats(moments="mv") + print("est mean and var: ", res) + assert False + # assert np.isclose(res[0], MEAN) + # assert np.isclose(res[1], VARIANCE) def test_gumbel():