From 3cc90c391d08ebba441ce8652a5bacb68d6d5825 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Wed, 25 May 2022 21:41:45 -0700 Subject: [PATCH] Improve perf of xr.corr --- xarray/core/computation.py | 14 ++++++++------ xarray/core/dataarray.py | 4 ++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 81b5e3fd915..097f42394e4 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -22,6 +22,8 @@ import numpy as np +from xarray.core.types import T_DataArray + from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align from .common import zeros_like @@ -1353,7 +1355,9 @@ def corr(da_a, da_b, dim=None): return _cov_corr(da_a, da_b, dim=dim, method="corr") -def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None): +def _cov_corr( + da_a: T_DataArray, da_b: T_DataArray, dim=None, ddof=0, method=None +) -> T_DataArray: """ Internal method for xr.cov() and xr.corr() so only have to sanitize the input arrays once and we don't repeat code. @@ -1372,11 +1376,9 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None): demeaned_da_b = da_b - da_b.mean(dim=dim) # 4. Compute covariance along the given dim - # N.B. `skipna=False` is required or there is a bug when computing - # auto-covariance. E.g. Try xr.cov(da,da) for - # da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]) - cov = (demeaned_da_a * demeaned_da_b).sum(dim=dim, skipna=True, min_count=1) / ( - valid_count + # `fillna` is required since `.dot` has no NaN tolerance. + cov = ( + demeaned_da_a.fillna(0.0).dot(demeaned_da_b.fillna(0.0), dims=dim) / valid_count ) if method == "cov": diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 64c4e419788..81ee8bd1d3d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3399,8 +3399,8 @@ def imag(self) -> DataArray: return self._replace(self.variable.imag) def dot( - self, other: DataArray, dims: Hashable | Sequence[Hashable] | None = None - ) -> DataArray: + self, other: T_DataArray, dims: Hashable | Sequence[Hashable] | None = None + ) -> T_DataArray: """Perform dot product of two DataArrays along their shared dims. Equivalent to taking taking tensordot over all shared dims.