Skip to content

Commit

Permalink
first attempt at implmenting fisk
Browse files Browse the repository at this point in the history
  • Loading branch information
mbi6245 committed Jul 22, 2024
1 parent ca701e0 commit 2fbf37a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
32 changes: 25 additions & 7 deletions src/ensemble/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,45 @@ def _create_scipy_dist(self) -> None:

class InvGamma(Distribution):
def _create_scipy_dist(self) -> None:
res = scipy.optimize.minimize(
optim_params = scipy.optimize.minimize(
fun=self._shape_scale,
# a *good* friend told me that this is a good initial guess and it works so far???
# alpha = 3 is because alpha > 2 must be true due to variance formula
# beta = mean * (alpha - 1) after isolating beta from formula for mean
x0=[3, self.mean * 2],
args=(self.mean, self.variance),
)
print("results from minimizer: ", res.x)
shape, scale = np.abs(res.x)
shape, scale = np.abs(optim_params.x)
self._scipy_dist = scipy.stats.invgamma(a=shape, scale=scale)

def _shape_scale(self, x, samp_mean, samp_var) -> None:
alpha = x[0]
beta = x[1]
return ((beta / (alpha - 1)) - samp_mean) ** 2 + (
(beta**2 / ((alpha - 1) ** 2 * (alpha - 2))) - samp_var
) ** 2
mean_guess = beta / (alpha - 1)
variance_guess = beta**2 / ((alpha - 1) ** 2 * (alpha - 2))
return (mean_guess - samp_mean) ** 2 + (variance_guess - samp_var) ** 2


class Fisk(Distribution):
def _create_scipy_dist(self):
raise NotImplementedError
optim_params = scipy.optimize.minimize(
fun=self._shape_scale,
x0=[2, self.mean * 2 / np.pi * np.sin(np.pi / 2)],
args=(self.mean, self.variance),
)
shape, scale = np.abs(optim_params.x)
print("parameters from optimizer: ", shape, scale)
self._scipy_dist = scipy.stats.fisk(c=shape, scale=scale)

def _shape_scale(self, x, samp_mean, samp_var) -> None:
alpha = x[0]
beta = x[1]
b = np.pi / beta
mean_guess = alpha * b / np.sin(b)
variance_guess = alpha**2 * (
(2 * b / np.sin(2 * b)) - b**2 / np.sin(b) ** 2
)
return (mean_guess - samp_mean) ** 2 + (variance_guess - samp_var) ** 2


class GumbelR(Distribution):
Expand Down
9 changes: 6 additions & 3 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@ def test_gamma():


def test_invgamma():
# raise NotImplementedError
invgamma = InvGamma(MEAN, VARIANCE)
res = invgamma.stats(moments="mv")
print("mean and var: ", res)
assert np.isclose(res[0], MEAN)
assert np.isclose(res[1], VARIANCE)


def test_fisk():
raise NotImplementedError
fisk = Fisk(MEAN, VARIANCE)
res = fisk.stats(moments="mv")
print("est mean and var: ", res)
assert False
# assert np.isclose(res[0], MEAN)
# assert np.isclose(res[1], VARIANCE)


def test_gumbel():
Expand Down

0 comments on commit 2fbf37a

Please sign in to comment.