Skip to content

Commit

Permalink
fix AR_result ~co~variance (MESMER-group#317)
Browse files: browse the repository at this point in the history
  • Loading branch information
mathause committed Oct 4, 2023
1 parent ee3ea68 commit 679540a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
2 changes: 1 addition & 1 deletion mesmer/calibrate_mesmer/train_gv.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def train_gv_AR(params_gv, gv, max_lag, sel_crit):
data = xr.DataArray(data, dims=("run", "time"))

params = _fit_auto_regression_xr(data, dim="time", lags=AR_order_sel)
# BUG/ TODO: we wrongfully average over the standard_deviation
# BUG/ TODO: we wrongfully average over the standard deviation
# see https://github.com/MESMER-group/mesmer/issues/307
params["standard_deviation"] = np.sqrt(params.covariance)
params = params.mean("run")
Expand Down
2 changes: 1 addition & 1 deletion mesmer/calibrate_mesmer/train_lv.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def train_lv_AR1_sci(params_lv, targs, y, wgt_scen_eq, aux, cfg):
data = xr.DataArray(data, dims=("run", "time", "cell"))

params = _fit_auto_regression_xr(data, dim="time", lags=1)
# BUG/ TODO: we wrongfully average over the standard_deviation
# BUG/ TODO: we wrongfully average over the standard deviation
# see https://github.com/MESMER-group/mesmer/issues/307
params["standard_deviation"] = np.sqrt(params.covariance)
params = params.mean("run")
Expand Down
17 changes: 8 additions & 9 deletions mesmer/stats/auto_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def _draw_auto_regression_correlated_np(
Notes
-----
The 'innovations' is the error or noise term.
As this is not a deterministic function, it is not called `predict`. "Predicting"
an autoregressive process does not include the innovations, so the prediction
asymptotes towards a certain value (in contrast to this function).
Expand All @@ -137,6 +135,7 @@ def _draw_auto_regression_correlated_np(
# ensure reproducibility (TODO: clarify approach to this, see #35)
np.random.seed(seed)

# NOTE: 'innovations' is the error or noise term.
# innovations has shape (n_samples, n_ts + buffer, n_coeffs)
innovations = np.random.multivariate_normal(
mean=np.zeros(n_coeffs),
Expand Down Expand Up @@ -170,15 +169,15 @@ def _fit_auto_regression_xr(data, dim, lags):
Returns
-------
:obj:`xr.Dataset`
Dataset containing the estimated parameters of the ``intercept``, the AR ``coeffs``
and the ``standard_deviation`` of the residuals.
Dataset containing the estimated parameters of the ``intercept``, the AR
``coeffs`` and the ``variance`` of the residuals.
"""

if not isinstance(data, xr.DataArray):
raise TypeError(f"Expected a `xr.DataArray`, got {type(data)}")

# NOTE: this is slowish, see https://github.com/MESMER-group/mesmer/pull/290
intercept, coeffs, covariance = xr.apply_ufunc(
intercept, coeffs, variance = xr.apply_ufunc(
_fit_auto_regression_np,
data,
input_core_dims=[[dim]],
Expand All @@ -194,7 +193,7 @@ def _fit_auto_regression_xr(data, dim, lags):
data_vars = {
"intercept": intercept,
"coeffs": coeffs,
"covariance": covariance,
"variance": variance,
"lags": lags,
}

Expand Down Expand Up @@ -230,7 +229,7 @@ def _fit_auto_regression_np(data, lags):
intercept = AR_result.params[0]
coeffs = AR_result.params[1:]

# covariance of the residuals
covariance = AR_result.sigma2
# variance of the residuals
variance = AR_result.sigma2

return intercept, coeffs, covariance
return intercept, coeffs, variance
16 changes: 8 additions & 8 deletions tests/unit/test_auto_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def test_draw_auto_regression_correlated_np_shape(ar_order, n_cells, n_samples,

intercept = np.zeros(n_cells)
coefs = np.ones((ar_order, n_cells))
covariance = np.ones((n_cells, n_cells))
variance = np.ones((n_cells, n_cells))

result = mesmer.stats.auto_regression._draw_auto_regression_correlated_np(
intercept=intercept,
coeffs=coefs,
covariance=covariance,
covariance=variance,
n_samples=n_samples,
n_ts=n_ts,
seed=0,
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_fit_auto_regression_xr_1D_values():
{
"intercept": 1.04728995,
"coeffs": ("lags", [0.99682459]),
"covariance": 1.05381192,
"variance": 1.05381192,
"lags": [1],
}
)
Expand All @@ -219,7 +219,7 @@ def test_fit_auto_regression_xr_1D_values_lags():
{
"intercept": 2.08295035,
"coeffs": ("lags", [0.99318256]),
"covariance": 1.18712735,
"variance": 1.18712735,
"lags": [2],
}
)
Expand All @@ -238,14 +238,14 @@ def test_fit_auto_regression_xr_1D(lags):
_check_dataset_form(
res,
"_fit_auto_regression_result",
required_vars=["intercept", "coeffs", "covariance"],
required_vars=["intercept", "coeffs", "variance"],
)

_check_dataarray_form(res.intercept, "intercept", ndim=0, shape=())
_check_dataarray_form(
res.coeffs, "coeffs", ndim=1, required_dims={"lags"}, shape=(len(lags),)
)
_check_dataarray_form(res.covariance, "covariance", ndim=0, shape=())
_check_dataarray_form(res.variance, "variance", ndim=0, shape=())

expected = xr.DataArray(lags, coords={"lags": lags})

Expand All @@ -263,7 +263,7 @@ def test_fit_auto_regression_xr_2D(lags):
_check_dataset_form(
res,
"_fit_auto_regression_result",
required_vars=["intercept", "coeffs", "covariance"],
required_vars=["intercept", "coeffs", "variance"],
)

_check_dataarray_form(res.intercept, "intercept", ndim=1, shape=(n_cells,))
Expand All @@ -274,7 +274,7 @@ def test_fit_auto_regression_xr_2D(lags):
required_dims={"cells", "lags"},
shape=(n_cells, lags),
)
_check_dataarray_form(res.covariance, "covariance", ndim=1, shape=(n_cells,))
_check_dataarray_form(res.variance, "variance", ndim=1, shape=(n_cells,))


@pytest.mark.parametrize("lags", [1, 2])
Expand Down

0 comments on commit 679540a

Please sign in to comment.