Skip to content

Commit

Permalink
revert some more and clean
Browse files Browse the repository at this point in the history
  • Loading branch information
veni-vidi-vici-dormivi committed Aug 22, 2024
1 parent 37d13ae commit cf04eb8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 23 deletions.
23 changes: 8 additions & 15 deletions mesmer/stats/_auto_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def _fit_auto_regression_scen_ens(*objs, dim, ens_dim, lags):
Dimension along which to fit the auto regression.
ens_dim : str
Dimension name of the ensemble members, None if no ensemble is provided.
If provided, must also have coordinates.
lags : int
The number of lags to include in the model.
Expand All @@ -91,25 +90,19 @@ def _fit_auto_regression_scen_ens(*objs, dim, ens_dim, lags):
"""

ar_params_scen = list()

for obj in objs:
ar_params = fit_auto_regression(obj, dim=dim, lags=int(lags))

#TODO: think about weighting! see https://github.com/MESMER-group/mesmer/issues/307
if ens_dim in ar_params.dims:
ar_params = ar_params.mean(ens_dim)

ar_params_scen.append(ar_params)

ar_params_scen = xr.concat(ar_params_scen, dim="scen")

# TODO: think about weighting! see https://github.com/MESMER-group/mesmer/issues/307
if ens_dim in ar_params.dims:
# mean over ensemble members
ar_params_scen = ar_params_scen.mean(dim=ens_dim)

# mean over scenarios
ar_params = ar_params_scen.mean(dim="scen")

# clean up
ar_params = ar_params.drop_vars("nobs") # don't need it in the result
if ens_dim in ar_params.dims:
ar_params = ar_params.drop_dims(ens_dim)

# return the mean over all scenarios
ar_params = ar_params_scen.mean("scen")

return ar_params

Expand Down
9 changes: 1 addition & 8 deletions tests/unit/test_auto_regression_scen_ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,7 @@ def test_fit_auto_regression_scen_ens_one_scen(std):
)

expected = mesmer.stats.fit_auto_regression(da, dim="time", lags=3)
expected["variance"] = sum(expected.variance * (expected.nobs - 1)) / sum(
expected.nobs - 1
)
expected["coeffs"] = expected.coeffs.mean("ens")
expected["intercept"] = expected.intercept.mean("ens")
expected = expected.drop_vars(["nobs", "ens"])
expected = expected.mean("ens")

xr.testing.assert_allclose(result, expected)
np.testing.assert_allclose(np.sqrt(result.variance), std, rtol=1e-1)
Expand All @@ -98,7 +93,6 @@ def test_fit_auto_regression_scen_ens_multi_scen():
expected = mesmer.stats.fit_auto_regression(da, dim="time", lags=3)
expected = expected.unstack("scen_ens")
expected = expected.mean("ens").mean("scen")
expected = expected.drop_vars(["nobs"])

xr.testing.assert_equal(result, expected)

Expand All @@ -112,6 +106,5 @@ def test_fit_auto_regression_scen_ens_no_ens_dim():
)

expected = mesmer.stats.fit_auto_regression(da, dim="time", lags=3)
expected = expected.drop_vars(["nobs"])

xr.testing.assert_allclose(result, expected)

0 comments on commit cf04eb8

Please sign in to comment.