Skip to content

Commit

Permalink
auto_regression: return covariance (#309)
Browse files Browse the repository at this point in the history
* auto_regression: return covariance

* add reference to #307

* CHANGELOG
  • Loading branch information
mathause authored Sep 27, 2023
1 parent 70a6a0e commit c31044f
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ New Features
- Add ``mesmer.stats.auto_regression._fit_auto_regression_xr``: xarray wrapper to fit an
auto regression model (`#139 <https://github.com/MESMER-group/mesmer/pull/139>`_).
By `Mathias Hauser`_.
- Have ``mesmer.stats.auto_regression._fit_auto_regression_xr`` return the covariance instead
of the standard deviation (`#306 <https://github.com/MESMER-group/mesmer/issues/306>`_).
By `Mathias Hauser`_.
- Add ``mesmer.stats.auto_regression._draw_auto_regression_correlated_np``: to draw samples of an
auto regression model (`#161 <https://github.com/MESMER-group/mesmer/pull/161>`_).
By `Mathias Hauser`_.
Expand Down
3 changes: 3 additions & 0 deletions mesmer/calibrate_mesmer/train_gv.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ def train_gv_AR(params_gv, gv, max_lag, sel_crit):
data = xr.DataArray(data, dims=("run", "time"))

params = _fit_auto_regression_xr(data, dim="time", lags=AR_order_sel)
# BUG/ TODO: we wrongfully average over the standard_deviation
# see https://github.com/MESMER-group/mesmer/issues/307
params["standard_deviation"] = np.sqrt(params.covariance)
params = params.mean("run")

params_scen.append(params)
Expand Down
5 changes: 4 additions & 1 deletion mesmer/calibrate_mesmer/train_lv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Functions to train local variability module of MESMER.
"""


import numpy as np
import xarray as xr

from mesmer.io.save_mesmer_bundle import save_mesmer_data
Expand Down Expand Up @@ -237,6 +237,9 @@ def train_lv_AR1_sci(params_lv, targs, y, wgt_scen_eq, aux, cfg):
data = xr.DataArray(data, dims=("run", "time", "cell"))

params = _fit_auto_regression_xr(data, dim="time", lags=1)
# BUG/ TODO: we wrongfully average over the standard_deviation
# see https://github.com/MESMER-group/mesmer/issues/307
params["standard_deviation"] = np.sqrt(params.covariance)
params = params.mean("run")

params_scen.append(params)
Expand Down
11 changes: 6 additions & 5 deletions mesmer/stats/auto_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def _fit_auto_regression_xr(data, dim, lags):
if not isinstance(data, xr.DataArray):
raise TypeError(f"Expected a `xr.DataArray`, got {type(data)}")

intercept, coeffs, std = xr.apply_ufunc(
# NOTE: this is slowish, see https://github.com/MESMER-group/mesmer/pull/290
intercept, coeffs, covariance = xr.apply_ufunc(
_fit_auto_regression_np,
data,
input_core_dims=[[dim]],
Expand All @@ -193,7 +194,7 @@ def _fit_auto_regression_xr(data, dim, lags):
data_vars = {
"intercept": intercept,
"coeffs": coeffs,
"standard_deviation": std,
"covariance": covariance,
"lags": lags,
}

Expand Down Expand Up @@ -229,7 +230,7 @@ def _fit_auto_regression_np(data, lags):
intercept = AR_result.params[0]
coeffs = AR_result.params[1:]

# standard deviation of the residuals
std = np.sqrt(AR_result.sigma2)
# covariance of the residuals
covariance = AR_result.sigma2

return intercept, coeffs, std
return intercept, coeffs, covariance
16 changes: 6 additions & 10 deletions tests/unit/test_auto_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_fit_auto_regression_xr_1D_values():
{
"intercept": 1.04728995,
"coeffs": ("lags", [0.99682459]),
"standard_deviation": 1.02655342,
"covariance": 1.05381192,
"lags": [1],
}
)
Expand All @@ -219,7 +219,7 @@ def test_fit_auto_regression_xr_1D_values_lags():
{
"intercept": 2.08295035,
"coeffs": ("lags", [0.99318256]),
"standard_deviation": 1.08955374,
"covariance": 1.18712735,
"lags": [2],
}
)
Expand All @@ -238,16 +238,14 @@ def test_fit_auto_regression_xr_1D(lags):
_check_dataset_form(
res,
"_fit_auto_regression_result",
required_vars=["intercept", "coeffs", "standard_deviation"],
required_vars=["intercept", "coeffs", "covariance"],
)

_check_dataarray_form(res.intercept, "intercept", ndim=0, shape=())
_check_dataarray_form(
res.coeffs, "coeffs", ndim=1, required_dims={"lags"}, shape=(len(lags),)
)
_check_dataarray_form(
res.standard_deviation, "standard_deviation", ndim=0, shape=()
)
_check_dataarray_form(res.covariance, "covariance", ndim=0, shape=())

expected = xr.DataArray(lags, coords={"lags": lags})

Expand All @@ -265,7 +263,7 @@ def test_fit_auto_regression_xr_2D(lags):
_check_dataset_form(
res,
"_fit_auto_regression_result",
required_vars=["intercept", "coeffs", "standard_deviation"],
required_vars=["intercept", "coeffs", "covariance"],
)

_check_dataarray_form(res.intercept, "intercept", ndim=1, shape=(n_cells,))
Expand All @@ -276,9 +274,7 @@ def test_fit_auto_regression_xr_2D(lags):
required_dims={"cells", "lags"},
shape=(n_cells, lags),
)
_check_dataarray_form(
res.standard_deviation, "standard_deviation", ndim=1, shape=(n_cells,)
)
_check_dataarray_form(res.covariance, "covariance", ndim=1, shape=(n_cells,))


@pytest.mark.parametrize("lags", [1, 2])
Expand Down

0 comments on commit c31044f

Please sign in to comment.