weighted: small improvements #4818

Merged 3 commits on Jan 27, 2021
11 changes: 10 additions & 1 deletion xarray/core/common.py
@@ -3,6 +3,7 @@
 from html import escape
 from textwrap import dedent
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -32,6 +33,12 @@
 ALL_DIMS = ...


+if TYPE_CHECKING:
+    from .dataarray import DataArray
+    from .weighted import Weighted
+
+T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
+
 C = TypeVar("C")
 T = TypeVar("T")

@@ -772,7 +779,9 @@ def groupby_bins(
             },
         )

-    def weighted(self, weights):
+    def weighted(
+        self: T_DataWithCoords, weights: "DataArray"
+    ) -> "Weighted[T_DataWithCoords]":
         """
         Weighted operations.

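With the annotated signature, `weighted` on a `DataArray` is typed as returning `Weighted[DataArray]` (and `Weighted[Dataset]` for a `Dataset`), so the weighted reductions are inferred to return the same type they were called from. A minimal usage sketch (not part of the diff; the example data is hypothetical) of what that typing describes:

import numpy as np
import xarray as xr

# Hypothetical example data; any DataArray plus a compatible weights DataArray works.
da = xr.DataArray(np.arange(4.0), dims="x")
weights = xr.DataArray([0.1, 0.2, 0.3, 0.4], dims="x")

# With the generic return type, a checker such as mypy can infer that this
# result is a DataArray rather than Union["DataArray", "Dataset"].
weighted_mean = da.weighted(weights).mean("x")
print(weighted_mean.values)

The common.py change is annotation-only, so runtime behaviour should be unchanged; only the static types are narrowed.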
49 changes: 18 additions & 31 deletions xarray/core/weighted.py
@@ -1,13 +1,16 @@
-from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload
+from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union

 from . import duck_array_ops
 from .computation import dot
 from .options import _get_keep_attrs
 from .pycompat import is_duck_dask_array

 if TYPE_CHECKING:
+    from .common import DataWithCoords  # noqa: F401
     from .dataarray import DataArray, Dataset

+T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
+

 _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
 Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).

@@ -56,7 +59,7 @@
"""


class Weighted:
class Weighted(Generic[T_DataWithCoords]):
"""An object that implements weighted operations.

You should create a Weighted object by using the ``DataArray.weighted`` or
@@ -70,15 +73,7 @@ class Weighted:

     __slots__ = ("obj", "weights")

-    @overload
-    def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
-        ...
-
-    @overload
-    def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
-        ...
-
-    def __init__(self, obj, weights):
+    def __init__(self, obj: T_DataWithCoords, weights: "DataArray"):
         """
         Create a Weighted object

@@ -121,8 +116,8 @@ def _weight_check(w):
         else:
             _weight_check(weights.data)

-        self.obj = obj
-        self.weights = weights
+        self.obj: T_DataWithCoords = obj
+        self.weights: "DataArray" = weights

     @staticmethod
     def _reduce(
@@ -146,7 +141,6 @@ def _reduce(

         # `dot` does not broadcast arrays, so this avoids creating a large
         # DataArray (if `weights` has additional dimensions)
-        # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`)
         return dot(da, weights, dims=dim)

     def _sum_of_weights(
@@ -203,7 +197,7 @@ def sum_of_weights(
         self,
         dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
         keep_attrs: Optional[bool] = None,
-    ) -> Union["DataArray", "Dataset"]:
+    ) -> T_DataWithCoords:

         return self._implementation(
             self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
@@ -214,7 +208,7 @@ def sum(
         dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
         skipna: Optional[bool] = None,
         keep_attrs: Optional[bool] = None,
-    ) -> Union["DataArray", "Dataset"]:
+    ) -> T_DataWithCoords:

         return self._implementation(
             self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs
@@ -225,7 +219,7 @@ def mean(
         dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
         skipna: Optional[bool] = None,
         keep_attrs: Optional[bool] = None,
-    ) -> Union["DataArray", "Dataset"]:
+    ) -> T_DataWithCoords:

         return self._implementation(
             self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
@@ -239,22 +233,15 @@ def __repr__(self):
         return f"{klass} with weights along dimensions: {weight_dims}"


-class DataArrayWeighted(Weighted):
-    def _implementation(self, func, dim, **kwargs):
-
-        keep_attrs = kwargs.pop("keep_attrs")
-        if keep_attrs is None:
-            keep_attrs = _get_keep_attrs(default=False)
-
-        weighted = func(self.obj, dim=dim, **kwargs)
-
-        if keep_attrs:
-            weighted.attrs = self.obj.attrs
-
-        return weighted
+class DataArrayWeighted(Weighted["DataArray"]):
+    def _implementation(self, func, dim, **kwargs) -> "DataArray":
+
+        dataset = self.obj._to_temp_dataset()
+        dataset = dataset.map(func, dim=dim, **kwargs)
+        return self.obj._from_temp_dataset(dataset)


-class DatasetWeighted(Weighted):
+class DatasetWeighted(Weighted["Dataset"]):
     def _implementation(self, func, dim, **kwargs) -> "Dataset":

         return self.obj.map(func, dim=dim, **kwargs)
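The `DataArrayWeighted._implementation` rewrite routes the `DataArray` case through a temporary single-variable `Dataset` and `Dataset.map`, so the hand-rolled `keep_attrs` handling (`kwargs.pop` plus `_get_keep_attrs`) is no longer needed in this class; `Dataset.map` appears to take care of it. A small check sketch with hypothetical data, assuming the refactor leaves user-facing behaviour unchanged:

import numpy as np
import xarray as xr

# Attributes should still survive a weighted reduction when keep_attrs=True.
da = xr.DataArray(np.arange(4.0), dims="x", attrs={"units": "kg"})
weights = xr.DataArray([1.0, 1.0, 2.0, 2.0], dims="x")

result = da.weighted(weights).mean("x", keep_attrs=True)
assert result.attrs == {"units": "kg"}

# Sanity check against the definition of a weighted mean (no missing values here).
expected = (da * weights).sum("x") / weights.sum("x")
np.testing.assert_allclose(result.values, expected.values)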