From bd78b7f1f46a0fb0a0b0f2d4f4bdbacba55be93d Mon Sep 17 00:00:00 2001
From: Dan Nowacki
Date: Thu, 16 May 2019 08:28:29 -0700
Subject: [PATCH 01/31] Implement load_dataset() and load_dataarray() (#2917)

* Partial fix for #2841 to improve formatting.

Updates formatting to use .format() instead of % operator. Changed all
instances of % to .format() and added test for using tuple as key, which
errored using % operator.

* Revert "Partial fix for #2841 to improve formatting."

This reverts commit f17f3ad1a4a2069cd70385af8ad331f644ec66ba.

* Implement load_dataset() and load_dataarray()

BUG: Fixes #2887 by adding @shoyer solution for load_dataset and
load_dataarray, wrappers around open_dataset and open_dataarray which
open, load, and close the file and return the Dataset/DataArray

TST: Add tests for sequentially opening and writing to files using new
functions

DOC: Add to whats-new.rst. Also a tiny change to the open_dataset
docstring

Update docstrings and check for cache in kwargs

Undeprecate load_dataset

Add to api.rst, fix whats-new.rst typo, raise error instead of warning
---
 doc/api.rst                   |  2 ++
 doc/whats-new.rst             | 12 ++++++--
 xarray/__init__.py            |  2 +-
 xarray/backends/api.py        | 57 +++++++++++++++++++++++++++++++++--
 xarray/tests/test_backends.py | 19 +++++++++++-
 xarray/tutorial.py            | 15 +++------
 6 files changed, 90 insertions(+), 17 deletions(-)

diff --git a/doc/api.rst b/doc/api.rst
index 00b33959eed..0e766f2cf9a 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -460,6 +460,7 @@ Dataset methods
    :toctree: generated/

    open_dataset
+   load_dataset
    open_mfdataset
    open_rasterio
    open_zarr
@@ -487,6 +488,7 @@ DataArray methods
    :toctree: generated/

    open_dataarray
+   load_dataarray
    DataArray.to_dataset
    DataArray.to_netcdf
    DataArray.to_pandas
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index ac1b5269bfa..d904a3814f1 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -29,6 +29,12 @@ Enhancements
   By `James McCreight `_.
 - Clean up Python 2 compatibility in code (:issue:`2950`)
   By `Guido Imperiale `_.
+- Implement ``load_dataset()`` and ``load_dataarray()`` as alternatives to
+  ``open_dataset()`` and ``open_dataarray()`` to open, load into memory,
+  and close files, returning the Dataset or DataArray. These functions are
+  helpful for avoiding file-lock errors when trying to write to files opened
+  using ``open_dataset()`` or ``open_dataarray()``. (:issue:`2887`)
+  By `Dan Nowacki `_.

 Bug fixes
 ~~~~~~~~~
@@ -153,9 +159,9 @@ Other enhancements
   By `Keisuke Fujii `_.
 - Added :py:meth:`~xarray.Dataset.drop_dims` (:issue:`1949`).
   By `Kevin Squire `_.
-- ``xr.open_zarr`` now accepts manually specified chunks with the ``chunks=``
-  parameter. ``auto_chunk=True`` is equivalent to ``chunks='auto'`` for
-  backwards compatibility. The ``overwrite_encoded_chunks`` parameter is
+- ``xr.open_zarr`` now accepts manually specified chunks with the ``chunks=``
+  parameter. ``auto_chunk=True`` is equivalent to ``chunks='auto'`` for
+  backwards compatibility. The ``overwrite_encoded_chunks`` parameter is
   added to remove the original zarr chunk encoding.
   By `Lily Wang `_.
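The failure mode this entry describes, in a minimal sketch (illustrative only; 'data.nc' is a placeholder filename, not from the patch):

    import xarray as xr

    ds = xr.open_dataset('data.nc')  # lazy: keeps a handle on data.nc
    ds.to_netcdf('data.nc')          # can fail while the file is still open
    ds = xr.load_dataset('data.nc')  # eager: reads everything, then closes
    ds.to_netcdf('data.nc')          # safe to overwrite in place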
diff --git a/xarray/__init__.py b/xarray/__init__.py
index 773dfe19d01..506cb46de26 100644
--- a/xarray/__init__.py
+++ b/xarray/__init__.py
@@ -17,7 +17,7 @@
 from .core.options import set_options

 from .backends.api import (open_dataset, open_dataarray, open_mfdataset,
-                           save_mfdataset)
+                           save_mfdataset, load_dataset, load_dataarray)
 from .backends.rasterio_ import open_rasterio
 from .backends.zarr import open_zarr

diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 7c5040580fe..01188e92752 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -185,12 +185,64 @@ def _finalize_store(write, store):
     store.close()


+def load_dataset(filename_or_obj, **kwargs):
+    """Open, load into memory, and close a Dataset from a file or file-like
+    object.
+
+    This is a thin wrapper around :py:meth:`~xarray.open_dataset`. It differs
+    from `open_dataset` in that it loads the Dataset into memory, closes the
+    file, and returns the Dataset. In contrast, `open_dataset` keeps the file
+    handle open and lazy loads its contents. All parameters are passed directly
+    to `open_dataset`. See that documentation for further details.
+
+    Returns
+    -------
+    dataset : Dataset
+        The newly created Dataset.
+
+    See Also
+    --------
+    open_dataset
+    """
+    if 'cache' in kwargs:
+        raise TypeError('cache has no effect in this context')
+
+    with open_dataset(filename_or_obj, **kwargs) as ds:
+        return ds.load()
+
+
+def load_dataarray(filename_or_obj, **kwargs):
+    """Open, load into memory, and close a DataArray from a file or file-like
+    object containing a single data variable.
+
+    This is a thin wrapper around :py:meth:`~xarray.open_dataarray`. It differs
+    from `open_dataarray` in that it loads the DataArray into memory, closes
+    the file, and returns the DataArray. In contrast, `open_dataarray` keeps
+    the file handle open and lazy loads its contents. All parameters are passed
+    directly to `open_dataarray`. See that documentation for further details.
+
+    Returns
+    -------
+    dataarray : DataArray
+        The newly created DataArray.
+
+    See Also
+    --------
+    open_dataarray
+    """
+    if 'cache' in kwargs:
+        raise TypeError('cache has no effect in this context')
+
+    with open_dataarray(filename_or_obj, **kwargs) as da:
+        return da.load()
+
+
 def open_dataset(filename_or_obj, group=None, decode_cf=True,
                  mask_and_scale=None, decode_times=True, autoclose=None,
                  concat_characters=True, decode_coords=True, engine=None,
                  chunks=None, lock=None, cache=None, drop_variables=None,
                  backend_kwargs=None, use_cftime=None):
-    """Load and decode a dataset from a file or file-like object.
+    """Open and decode a dataset from a file or file-like object.

     Parameters
     ----------
@@ -406,7 +458,8 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True,
                    concat_characters=True, decode_coords=True, engine=None,
                    chunks=None, lock=None, cache=None, drop_variables=None,
                    backend_kwargs=None, use_cftime=None):
-    """Open an DataArray from a netCDF file containing a single data variable.
+    """Open a DataArray from a file or file-like object containing a single
+    data variable.

     This is designed to read netCDF files with only one data variable. If
     multiple variables are present then a ValueError is raised.
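How the new wrappers behave, sketched against the implementation above (the path is a placeholder):

    import xarray as xr

    ds = xr.load_dataset('obs.nc')
    # equivalent to:
    with xr.open_dataset('obs.nc') as f:
        ds = f.load()

    # 'cache' is rejected explicitly, since it has no effect here:
    xr.load_dataset('obs.nc', cache=False)  # raises TypeError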
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index a4c0374e158..f31d3bf4f9b 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -19,7 +19,7 @@
 import xarray as xr
 from xarray import (
     DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset,
-    save_mfdataset)
+    save_mfdataset, load_dataset, load_dataarray)
 from xarray.backends.common import robust_getitem
 from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
 from xarray.backends.pydap_ import PydapDataStore
@@ -2641,6 +2641,23 @@ def test_save_mfdataset_compute_false_roundtrip(self):
             with open_mfdataset([tmp1, tmp2]) as actual:
                 assert_identical(actual, original)

+    def test_load_dataset(self):
+        with create_tmp_file() as tmp:
+            original = Dataset({'foo': ('x', np.random.randn(10))})
+            original.to_netcdf(tmp)
+            ds = load_dataset(tmp)
+            # this would fail if we used open_dataset instead of load_dataset
+            ds.to_netcdf(tmp)
+
+    def test_load_dataarray(self):
+        with create_tmp_file() as tmp:
+            original = Dataset({'foo': ('x', np.random.randn(10))})
+            original.to_netcdf(tmp)
+            ds = load_dataarray(tmp)
+            # this would fail if we used open_dataarray instead of
+            # load_dataarray
+            ds.to_netcdf(tmp)
+

 @requires_scipy_or_netCDF4
 @requires_pydap
diff --git a/xarray/tutorial.py b/xarray/tutorial.py
index f54cf7b3889..1a977450ed6 100644
--- a/xarray/tutorial.py
+++ b/xarray/tutorial.py
@@ -27,7 +27,7 @@ def open_dataset(name, cache=True, cache_dir=_default_cache_dir,
                  github_url='https://github.com/pydata/xarray-data',
                  branch='master', **kws):
     """
-    Load a dataset from the online repository (requires internet).
+    Open a dataset from the online repository (requires internet).

     If a local copy is found then always use that to avoid network traffic.

@@ -91,17 +91,12 @@ def open_dataset(name, cache=True, cache_dir=_default_cache_dir,

 def load_dataset(*args, **kwargs):
     """
-    `load_dataset` will be removed a future version of xarray. The current
-    behavior of this function can be achived by using
-    `tutorial.open_dataset(...).load()`.
+    Open, load into memory, and close a dataset from the online repository
+    (requires internet).

     See Also
    --------
     open_dataset
     """
-    warnings.warn(
-        "`load_dataset` will be removed in a future version of xarray. The "
-        "current behavior of this function can be achived by using "
-        "`tutorial.open_dataset(...).load()`.",
-        DeprecationWarning, stacklevel=2)
-    return open_dataset(*args, **kwargs).load()
+    with open_dataset(*args, **kwargs) as ds:
+        return ds.load()

From 66581084a89f75476b581ef74e5226eae2d62a84 Mon Sep 17 00:00:00 2001
From: Kevin Squire
Date: Thu, 16 May 2019 08:55:15 -0700
Subject: [PATCH 02/31] Fix rolling.construct() example (#2967)

* The example was using the wrong name for the function (to_datarray), and
  used the wrong dimension for the window
---
 xarray/core/rolling.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index ad9b17fef92..c113cfebe2a 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -170,15 +170,15 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
         --------
         >>> da = DataArray(np.arange(8).reshape(2, 4), dims=('a', 'b'))
         >>>
-        >>> rolling = da.rolling(a=3)
-        >>> rolling.to_datarray('window_dim')
+        >>> rolling = da.rolling(b=3)
+        >>> rolling.construct('window_dim')
         <xarray.DataArray (a: 2, b: 4, window_dim: 3)>
         array([[[np.nan, np.nan, 0], [np.nan, 0, 1], [0, 1, 2], [1, 2, 3]],
                [[np.nan, np.nan, 4], [np.nan, 4, 5], [4, 5, 6], [5, 6, 7]]])
         Dimensions without coordinates: a, b, window_dim
         >>>
-        >>> rolling = da.rolling(a=3, center=True)
-        >>> rolling.to_datarray('window_dim')
+        >>> rolling = da.rolling(b=3, center=True)
+        >>> rolling.construct('window_dim')
         <xarray.DataArray (a: 2, b: 4, window_dim: 3)>
         array([[[np.nan, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, np.nan]],
                [[np.nan, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, np.nan]]])

From 0811141e8f985a1f3b95ead92c3850cc74e160a5 Mon Sep 17 00:00:00 2001
From: Peter Hausamann
Date: Tue, 21 May 2019 19:37:54 +0200
Subject: [PATCH 03/31] Add transpose_coords option to DataArray.transpose
 (#2556)

* Add transpose_coords option to DataArray.transpose

Fixes #1856

* Fix typo

* Fix bug in transpose

Fix python 2 compatibility

* Set default for transpose_coords to None

Update documentation

* Fix bug in coordinate transpose

Update documentation

* Suppress FutureWarning in tests

* Add restore_coord_dims parameter to DataArrayGroupBy.apply

* Move restore_coord_dims parameter to GroupBy class

* Remove restore_coord_dims parameter from DataArrayResample.apply

* Update whats-new

* Update whats-new
---
 doc/whats-new.rst              |  7 +++++
 xarray/core/common.py          | 27 ++++++++++++----
 xarray/core/dataarray.py       | 26 ++++++++++++++--
 xarray/core/groupby.py         | 25 ++++++++++++---
 xarray/plot/plot.py            | 14 ++++++---
 xarray/tests/test_dataarray.py | 57 ++++++++++++++++++++++++++++------
 xarray/tests/test_dataset.py   | 12 +++++--
 xarray/tests/test_interp.py    |  6 ++--
 xarray/tests/test_plot.py      |  3 +-
 9 files changed, 144 insertions(+), 33 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index d904a3814f1..ab7c155950c 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -27,6 +27,13 @@ Enhancements
 - Character arrays' character dimension name decoding and encoding handled by
   ``var.encoding['char_dim_name']`` (:issue:`2895`)
   By `James McCreight `_.
+- :py:meth:`DataArray.transpose` now accepts a keyword argument
+  ``transpose_coords`` which enables transposition of coordinates in the
+  same way as :py:meth:`Dataset.transpose`. :py:meth:`DataArray.groupby`,
+  :py:meth:`DataArray.groupby_bins`, and :py:meth:`DataArray.resample` now
+  accept a keyword argument ``restore_coord_dims`` which keeps the order
+  of the dimensions of multi-dimensional coordinates intact (:issue:`1856`).
+  By `Peter Hausamann `_.
 - Clean up Python 2 compatibility in code (:issue:`2950`)
   By `Guido Imperiale `_.
 - Implement ``load_dataset()`` and ``load_dataarray()`` as alternatives to
diff --git a/xarray/core/common.py b/xarray/core/common.py
index b518e8431fd..00d0383a727 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -441,7 +441,8 @@ def pipe(self, func: Union[Callable[..., T], Tuple[Callable[..., T], str]],
         else:
             return func(self, *args, **kwargs)

-    def groupby(self, group, squeeze: bool = True):
+    def groupby(self, group, squeeze: bool = True,
+                restore_coord_dims: Optional[bool] = None):
         """Returns a GroupBy object for performing grouped operations.

         Parameters
@@ -453,6 +454,9 @@ def groupby(self, group, squeeze: bool = True):
             If "group" is a dimension of any arrays in this dataset, `squeeze`
             controls whether the subarrays have a dimension of length 1 along
             that dimension or if the dimension is squeezed out.
+        restore_coord_dims : bool, optional
+            If True, also restore the dimension order of multi-dimensional
+            coordinates.

         Returns
         -------
@@ -485,11 +489,13 @@ def groupby(self, group, squeeze: bool = True):
         core.groupby.DataArrayGroupBy
         core.groupby.DatasetGroupBy
         """  # noqa
-        return self._groupby_cls(self, group, squeeze=squeeze)
+        return self._groupby_cls(self, group, squeeze=squeeze,
+                                 restore_coord_dims=restore_coord_dims)

     def groupby_bins(self, group, bins, right: bool = True, labels=None,
                      precision: int = 3, include_lowest: bool = False,
-                     squeeze: bool = True):
+                     squeeze: bool = True,
+                     restore_coord_dims: Optional[bool] = None):
         """Returns a GroupBy object for performing grouped operations.

         Rather than using all unique values of `group`, the values are discretized
@@ -522,6 +528,9 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None,
             If "group" is a dimension of any arrays in this dataset, `squeeze`
             controls whether the subarrays have a dimension of length 1 along
             that dimension or if the dimension is squeezed out.
+        restore_coord_dims : bool, optional
+            If True, also restore the dimension order of multi-dimensional
+            coordinates.

         Returns
         -------
@@ -536,9 +545,11 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None,
         .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
         """  # noqa
         return self._groupby_cls(self, group, squeeze=squeeze, bins=bins,
+                                 restore_coord_dims=restore_coord_dims,
                                  cut_kwargs={'right': right, 'labels': labels,
                                              'precision': precision,
-                                             'include_lowest': include_lowest})
+                                             'include_lowest':
+                                                 include_lowest})

     def rolling(self, dim: Optional[Mapping[Hashable, int]] = None,
                 min_periods: Optional[int] = None, center: bool = False,
@@ -669,7 +680,7 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
                  skipna=None, closed: Optional[str] = None,
                  label: Optional[str] = None,
                  base: int = 0, keep_attrs: Optional[bool] = None,
-                 loffset=None,
+                 loffset=None, restore_coord_dims: Optional[bool] = None,
                  **indexer_kwargs: str):
         """Returns a Resample object for performing resampling operations.

@@ -697,6 +708,9 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
             If True, the object's attributes (`attrs`) will be copied from
             the original object to the new one.  If False (default), the new
             object will be returned without attributes.
+        restore_coord_dims : bool, optional
+            If True, also restore the dimension order of multi-dimensional
+            coordinates.
         **indexer_kwargs : {dim: freq}
             The keyword arguments form of ``indexer``.
             One of indexer or indexer_kwargs must be provided.
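A minimal sketch of the new ``transpose_coords`` keyword (array and coordinate names invented for illustration):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.zeros((2, 3)), dims=('x', 'y'),
                      coords={'xy': (('x', 'y'), np.ones((2, 3)))})
    da.transpose('y', 'x', transpose_coords=True)   # 'xy' now has dims ('y', 'x')
    da.transpose('y', 'x', transpose_coords=False)  # 'xy' keeps dims ('x', 'y')
    da.transpose('y', 'x')  # multi-dim coords present: emits a FutureWarning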
@@ -786,7 +800,8 @@ def resample(self, indexer: Optional[Mapping[Hashable, str]] = None,
                               dims=dim_coord.dims, name=RESAMPLE_DIM)
         resampler = self._resample_cls(self, group=group, dim=dim_name,
                                        grouper=grouper,
-                                       resample_dim=RESAMPLE_DIM)
+                                       resample_dim=RESAMPLE_DIM,
+                                       restore_coord_dims=restore_coord_dims)

         return resampler

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 8d3836f5d8c..15abdaf4a92 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1405,7 +1405,7 @@ def unstack(self, dim=None):
         ds = self._to_temp_dataset().unstack(dim)
         return self._from_temp_dataset(ds)

-    def transpose(self, *dims) -> 'DataArray':
+    def transpose(self, *dims, transpose_coords=None) -> 'DataArray':
         """Return a new DataArray object with transposed dimensions.

         Parameters
@@ -1413,6 +1413,8 @@ def transpose(self, *dims) -> 'DataArray':
         *dims : str, optional
             By default, reverse the dimensions. Otherwise, reorder the
             dimensions to this order.
+        transpose_coords : boolean, optional
+            If True, also transpose the coordinates of this DataArray.

         Returns
         -------
@@ -1430,8 +1432,28 @@ def transpose(self, *dims) -> 'DataArray':
         numpy.transpose
         Dataset.transpose
         """
+        if dims:
+            if set(dims) ^ set(self.dims):
+                raise ValueError('arguments to transpose (%s) must be '
+                                 'permuted array dimensions (%s)'
+                                 % (dims, tuple(self.dims)))
+
         variable = self.variable.transpose(*dims)
-        return self._replace(variable)
+        if transpose_coords:
+            coords = {}
+            for name, coord in self.coords.items():
+                coord_dims = tuple(dim for dim in dims if dim in coord.dims)
+                coords[name] = coord.variable.transpose(*coord_dims)
+            return self._replace(variable, coords)
+        else:
+            if transpose_coords is None \
+                    and any(self[c].ndim > 1 for c in self.coords):
+                warnings.warn('This DataArray contains multi-dimensional '
+                              'coordinates. In the future, these coordinates '
+                              'will be transposed as well unless you specify '
+                              'transpose_coords=False.',
+                              FutureWarning, stacklevel=2)
+            return self._replace(variable)

     @property
     def T(self) -> 'DataArray':
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 82a92044caf..d7dcb5b0426 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -197,7 +197,7 @@ class GroupBy(SupportsArithmetic):
     """
     def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
-                 cut_kwargs={}):
+                 restore_coord_dims=None, cut_kwargs={}):
         """Create a GroupBy object

         Parameters
@@ -215,6 +215,9 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
         bins : array-like, optional
             If `bins` is specified, the groups will be discretized into the
             specified bins by `pandas.cut`.
+        restore_coord_dims : bool, optional
+            If True, also restore the dimension order of multi-dimensional
+            coordinates.
         cut_kwargs : dict, optional
             Extra keyword arguments to pass to `pandas.cut`

@@ -279,6 +282,16 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
                 safe_cast_to_index(group), sort=(bins is None))
             unique_coord = IndexVariable(group.name, unique_values)

+        if isinstance(obj, DataArray) \
+                and restore_coord_dims is None \
+                and any(obj[c].ndim > 1 for c in obj.coords):
+            warnings.warn('This DataArray contains multi-dimensional '
+                          'coordinates. In the future, the dimension order '
+                          'of these coordinates will be restored as well '
+                          'unless you specify restore_coord_dims=False.',
+                          FutureWarning, stacklevel=2)
+            restore_coord_dims = False
+
         # specification for the groupby operation
         self._obj = obj
         self._group = group
@@ -288,6 +301,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
         self._stacked_dim = stacked_dim
         self._inserted_dims = inserted_dims
         self._full_index = full_index
+        self._restore_coord_dims = restore_coord_dims

         # cached attributes
         self._groups = None
@@ -508,7 +522,8 @@ def lookup_order(dimension):
             return axis

         new_order = sorted(stacked.dims, key=lookup_order)
-        return stacked.transpose(*new_order)
+        return stacked.transpose(
+            *new_order, transpose_coords=self._restore_coord_dims)

     def apply(self, func, shortcut=False, args=(), **kwargs):
         """Apply a function over each array in the group and concatenate them
@@ -558,7 +573,7 @@ def apply(self, func, shortcut=False, args=(), **kwargs):
                    for arr in grouped)
         return self._combine(applied, shortcut=shortcut)

-    def _combine(self, applied, shortcut=False):
+    def _combine(self, applied, restore_coord_dims=False, shortcut=False):
         """Recombine the applied objects like the original."""
         applied_example, applied = peek_at(applied)
         coord, dim, positions = self._infer_concat_args(applied_example)
@@ -580,8 +595,8 @@ def _combine(self, applied, shortcut=False):
         combined = self._maybe_unstack(combined)
         return combined

-    def reduce(self, func, dim=None, axis=None,
-               keep_attrs=None, shortcut=True, **kwargs):
+    def reduce(self, func, dim=None, axis=None, keep_attrs=None,
+               shortcut=True, **kwargs):
         """Reduce the items in this group by applying `func` along some
         dimension(s).

diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py
index 316d4fb4dd9..d4cb1a7726b 100644
--- a/xarray/plot/plot.py
+++ b/xarray/plot/plot.py
@@ -64,8 +64,10 @@ def _infer_line_data(darray, x, y, hue):
             if huename in darray.dims:
                 otherindex = 1 if darray.dims.index(huename) == 0 else 0
                 otherdim = darray.dims[otherindex]
-                yplt = darray.transpose(otherdim, huename)
-                xplt = xplt.transpose(otherdim, huename)
+                yplt = darray.transpose(
+                    otherdim, huename, transpose_coords=False)
+                xplt = xplt.transpose(
+                    otherdim, huename, transpose_coords=False)
             else:
                 raise ValueError('For 2D inputs, hue must be a dimension'
                                  + ' i.e. one of ' + repr(darray.dims))
@@ -79,7 +81,9 @@ def _infer_line_data(darray, x, y, hue):
             if yplt.ndim > 1:
                 if huename in darray.dims:
                     otherindex = 1 if darray.dims.index(huename) == 0 else 0
-                    xplt = darray.transpose(otherdim, huename)
+                    otherdim = darray.dims[otherindex]
+                    xplt = darray.transpose(
+                        otherdim, huename, transpose_coords=False)
                 else:
                     raise ValueError('For 2D inputs, hue must be a dimension'
                                      + ' i.e. one of ' + repr(darray.dims))
@@ -614,9 +618,9 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
             yx_dims = (ylab, xlab)
             dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
             if dims != darray.dims:
-                darray = darray.transpose(*dims)
+                darray = darray.transpose(*dims, transpose_coords=True)
         elif darray[xlab].dims[-1] == darray.dims[0]:
-            darray = darray.transpose()
+            darray = darray.transpose(transpose_coords=True)

         # Pass the data as a masked ndarray too
         zval = darray.to_masked_array(copy=False)
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 9471ec144c0..43af27d0696 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -1681,14 +1681,14 @@ def test_math_with_coords(self):
         assert_identical(expected, actual)

         actual = orig[0, :] + orig[:, 0]
-        assert_identical(expected.T, actual)
+        assert_identical(expected.transpose(transpose_coords=True), actual)

-        actual = orig - orig.T
+        actual = orig - orig.transpose(transpose_coords=True)
         expected = DataArray(np.zeros((2, 3)), orig.coords)
         assert_identical(expected, actual)

-        actual = orig.T - orig
-        assert_identical(expected.T, actual)
+        actual = orig.transpose(transpose_coords=True) - orig
+        assert_identical(expected.transpose(transpose_coords=True), actual)

         alt = DataArray([1, 1], {'x': [-1, -2], 'c': 'foo', 'd': 555}, 'x')
         actual = orig + alt
@@ -1801,8 +1801,27 @@ def test_stack_nonunique_consistency(self):
         assert_identical(expected, actual)

     def test_transpose(self):
-        assert_equal(self.dv.variable.transpose(),
-                     self.dv.transpose().variable)
+        da = DataArray(np.random.randn(3, 4, 5), dims=('x', 'y', 'z'),
+                       coords={'x': range(3), 'y': range(4), 'z': range(5),
+                               'xy': (('x', 'y'), np.random.randn(3, 4))})
+
+        actual = da.transpose(transpose_coords=False)
+        expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
+                             coords=da.coords)
+        assert_equal(expected, actual)
+
+        actual = da.transpose('z', 'y', 'x', transpose_coords=True)
+        expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
+                             coords={'x': da.x.values, 'y': da.y.values,
+                                     'z': da.z.values,
+                                     'xy': (('y', 'x'), da.xy.values.T)})
+        assert_equal(expected, actual)
+
+        with pytest.raises(ValueError):
+            da.transpose('x', 'y')
+
+        with pytest.warns(FutureWarning):
+            da.transpose()

     def test_squeeze(self):
         assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)
@@ -2258,6 +2277,23 @@ def test_groupby_restore_dim_order(self):
             result = array.groupby(by).apply(lambda x: x.squeeze())
             assert result.dims == expected_dims

+    def test_groupby_restore_coord_dims(self):
+        array = DataArray(np.random.randn(5, 3),
+                          coords={'a': ('x', range(5)), 'b': ('y', range(3)),
+                                  'c': (('x', 'y'), np.random.randn(5, 3))},
+                          dims=['x', 'y'])
+
+        for by, expected_dims in [('x', ('x', 'y')),
+                                  ('y', ('x', 'y')),
+                                  ('a', ('a', 'y')),
+                                  ('b', ('x', 'b'))]:
+            result = array.groupby(by, restore_coord_dims=True).apply(
+                lambda x: x.squeeze())['c']
+            assert result.dims == expected_dims
+
+        with pytest.warns(FutureWarning):
+            array.groupby('x').apply(lambda x: x.squeeze())
+
     def test_groupby_first_and_last(self):
         array = DataArray([1, 2, 3, 4, 5], dims='x')
         by = DataArray(['a'] * 2 + ['b'] * 3, dims='x', name='ab')
@@ -2445,15 +2481,18 @@ def test_resample_drop_nondim_coords(self):
         array = ds['data']

         # Re-sample
-        actual = array.resample(time="12H").mean('time')
+        actual = array.resample(
+            time="12H", restore_coord_dims=True).mean('time')
         assert 'tc' not in actual.coords

         # Up-sample - filling
-        actual = array.resample(time="1H").ffill()
+        actual = array.resample(
+            time="1H", restore_coord_dims=True).ffill()
         assert 'tc' not in actual.coords

         # Up-sample - interpolation
-        actual = array.resample(time="1H").interpolate('linear')
+        actual = array.resample(
+            time="1H", restore_coord_dims=True).interpolate('linear')
         assert 'tc' not in actual.coords

     def test_resample_keep_attrs(self):
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index b47e26328ad..ecacf43caf4 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -4062,14 +4062,20 @@ def test_dataset_math_errors(self):

     def test_dataset_transpose(self):
         ds = Dataset({'a': (('x', 'y'), np.random.randn(3, 4)),
-                      'b': (('y', 'x'), np.random.randn(4, 3))})
+                      'b': (('y', 'x'), np.random.randn(4, 3))},
+                     coords={'x': range(3), 'y': range(4),
+                             'xy': (('x', 'y'), np.random.randn(3, 4))})

         actual = ds.transpose()
-        expected = ds.apply(lambda x: x.transpose())
+        expected = Dataset({'a': (('y', 'x'), ds.a.values.T),
+                            'b': (('x', 'y'), ds.b.values.T)},
+                           coords={'x': ds.x.values, 'y': ds.y.values,
+                                   'xy': (('y', 'x'), ds.xy.values.T)})
         assert_identical(expected, actual)

         actual = ds.transpose('x', 'y')
-        expected = ds.apply(lambda x: x.transpose('x', 'y'))
+        expected = ds.apply(
+            lambda x: x.transpose('x', 'y', transpose_coords=True))
         assert_identical(expected, actual)

         ds = create_test_data()
diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py
index 8347d54bd1e..a11e4b9e79a 100644
--- a/xarray/tests/test_interp.py
+++ b/xarray/tests/test_interp.py
@@ -143,7 +143,8 @@ def func(obj, dim, new_x):
                             'y': da['y'],
                             'x': ('z', xdest.values),
                             'x2': ('z', func(da['x2'], 'x', xdest))})
-    assert_allclose(actual, expected.transpose('z', 'y'))
+    assert_allclose(actual,
+                    expected.transpose('z', 'y', transpose_coords=True))

     # xdest is 2d
     xdest = xr.DataArray(np.linspace(0.1, 0.9, 30).reshape(6, 5),
@@ -160,7 +161,8 @@ def func(obj, dim, new_x):
         coords={'z': xdest['z'], 'w': xdest['w'], 'z2': xdest['z2'],
                 'y': da['y'], 'x': (('z', 'w'), xdest),
                 'x2': (('z', 'w'), func(da['x2'], 'x', xdest))})
-    assert_allclose(actual, expected.transpose('z', 'w', 'y'))
+    assert_allclose(actual,
+                    expected.transpose('z', 'w', 'y', transpose_coords=True))


 @pytest.mark.parametrize('case', [3, 4])
diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py
index 759a2974ca6..84510da65fe 100644
--- a/xarray/tests/test_plot.py
+++ b/xarray/tests/test_plot.py
@@ -1202,7 +1202,8 @@ def test_cmap_and_color_both(self):

     def test_2d_coord_with_interval(self):
         for dim in self.darray.dims:
-            gp = self.darray.groupby_bins(dim, range(15)).mean(dim)
+            gp = self.darray.groupby_bins(
+                dim, range(15), restore_coord_dims=True).mean(dim)
             for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
                 getattr(gp.plot, kind)()

From 7edf2e20d4c898fbb637da3b3e6ded15808e040b Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
Date: Fri, 24 May 2019 22:01:31 -0400
Subject: [PATCH 04/31] Remove deprecated pytest.config usages (#2988)

* initial attempt at moving away from deprecated pytest.config

* whatsnew
---
 conftest.py              | 20 ++++++++++++++++++++
 doc/whats-new.rst        |  2 ++
 xarray/tests/__init__.py | 19 ++-----------------
 3 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/conftest.py b/conftest.py
index d7f4e0c89bc..ffceb27e753 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,5 +1,7 @@
 """Configuration for pytest."""

+import pytest
+

 def pytest_addoption(parser):
     """Add command-line flags for pytest."""
     parser.addoption("--run-flaky", action="store_true",
                      help="runs flaky tests")
     parser.addoption("--run-network-tests", action="store_true",
                      help="runs tests requiring a network connection")
+
+
+def pytest_collection_modifyitems(config, items):
+
+    if not config.getoption("--run-flaky"):
+        skip_flaky = pytest.mark.skip(
+            reason="set --run-flaky option to run flaky tests")
+        for item in items:
+            if "flaky" in item.keywords:
+                item.add_marker(skip_flaky)
+
+    if not config.getoption("--run-network-tests"):
+        skip_network = pytest.mark.skip(
+            reason="set --run-network-tests option to run tests requiring an "
+            "internet connection")
+        for item in items:
+            if "network" in item.keywords:
+                item.add_marker(skip_network)
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index ab7c155950c..dfdca55d218 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -54,6 +54,8 @@ Bug fixes
   By `Deepak Cherian `_.
+- Removed usages of `pytest.config`, which is deprecated (:issue:`2988`)
+  By `Maximilian Roos `_.

 .. _whats-new.0.12.1:
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index e9d670e4dd9..5e559fce526 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -108,23 +108,8 @@ def LooseVersion(vstring):
 else:
     dask.config.set(scheduler='single-threaded')

-# pytest config
-try:
-    _SKIP_FLAKY = not pytest.config.getoption("--run-flaky")
-    _SKIP_NETWORK_TESTS = not pytest.config.getoption("--run-network-tests")
-except (ValueError, AttributeError):
-    # Can't get config from pytest, e.g., because xarray is installed instead
-    # of being run from a development version (and hence conftests.py is not
-    # available). Don't run flaky tests.
-    _SKIP_FLAKY = True
-    _SKIP_NETWORK_TESTS = True
-
-flaky = pytest.mark.skipif(
-    _SKIP_FLAKY, reason="set --run-flaky option to run flaky tests")
-network = pytest.mark.skipif(
-    _SKIP_NETWORK_TESTS,
-    reason="set --run-network-tests option to run tests requiring an "
-    "internet connection")
+flaky = pytest.mark.flaky
+network = pytest.mark.network


 @contextmanager

From 6dc8b60849fab48f24494859c15a42f078025be6 Mon Sep 17 00:00:00 2001
From: Zach Griffith
Date: Sun, 26 May 2019 19:20:54 -0500
Subject: [PATCH 05/31] Add fill_value for concat and auto_combine (#2964)

* add fill_value option for concat and auto_combine

* add tests for fill_value in concat and auto_combine

* remove errant whitespace

* add fill_value description to doc-string

* add missing assert
---
 xarray/core/combine.py       | 55 +++++++++++++++++++++++-------------
 xarray/tests/test_combine.py | 42 +++++++++++++++++++++++++++
 2 files changed, 77 insertions(+), 20 deletions(-)

diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index 1abd14cd20b..6d922064f6f 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -4,7 +4,7 @@

 import pandas as pd

-from . import utils
+from . import utils, dtypes
 from .alignment import align
 from .merge import merge
 from .variable import IndexVariable, Variable, as_variable
@@ -14,7 +14,7 @@

 def concat(objs, dim=None, data_vars='all', coords='different',
            compat='equals', positions=None, indexers=None, mode=None,
-           concat_over=None):
+           concat_over=None, fill_value=dtypes.NA):
     """Concatenate xarray objects along a new or existing dimension.

     Parameters
@@ -66,6 +66,8 @@ def concat(objs, dim=None, data_vars='all', coords='different',
         List of integer arrays which specifies the integer positions to which
         to assign each dataset along the concatenated dimension. If not
         supplied, objects are concatenated in the provided order.
+    fill_value : scalar, optional
+        Value to use for newly missing values
     indexers, mode, concat_over : deprecated

     Returns
@@ -117,7 +119,7 @@ def concat(objs, dim=None, data_vars='all', coords='different',
     else:
         raise TypeError('can only concatenate xarray Dataset and DataArray '
                         'objects, got %s' % type(first_obj))
-    return f(objs, dim, data_vars, coords, compat, positions)
+    return f(objs, dim, data_vars, coords, compat, positions, fill_value)


 def _calc_concat_dim_coord(dim):
@@ -212,7 +214,8 @@ def process_subset_opt(opt, subset):
     return concat_over, equals


-def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
+def _dataset_concat(datasets, dim, data_vars, coords, compat, positions,
+                    fill_value=dtypes.NA):
     """
     Concatenate a sequence of datasets along a new or existing dimension
     """
@@ -225,7 +228,8 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
     dim, coord = _calc_concat_dim_coord(dim)
     # Make sure we're working on a copy (we'll be loading variables)
     datasets = [ds.copy() for ds in datasets]
-    datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
+    datasets = align(*datasets, join='outer', copy=False, exclude=[dim],
+                     fill_value=fill_value)

     concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)

@@ -317,7 +321,7 @@ def ensure_common_dims(vars):


 def _dataarray_concat(arrays, dim, data_vars, coords, compat,
-                      positions):
+                      positions, fill_value=dtypes.NA):
     arrays = list(arrays)

     if data_vars != 'all':
@@ -336,14 +340,15 @@ def _dataarray_concat(arrays, dim, data_vars, coords, compat,
             datasets.append(arr._to_temp_dataset())

     ds = _dataset_concat(datasets, dim, data_vars, coords, compat,
-                         positions)
+                         positions, fill_value)
     result = arrays[0]._from_temp_dataset(ds, name)

     result.name = result_name(arrays)
     return result


-def _auto_concat(datasets, dim=None, data_vars='all', coords='different'):
+def _auto_concat(datasets, dim=None, data_vars='all', coords='different',
+                 fill_value=dtypes.NA):
     if len(datasets) == 1 and dim is None:
         # There is nothing more to combine, so kick out early.
         return datasets[0]
@@ -366,7 +371,8 @@ def _auto_concat(datasets, dim=None, data_vars='all', coords='different'):
                              'supply the ``concat_dim`` argument '
                              'explicitly')
         dim, = concat_dims
-    return concat(datasets, dim=dim, data_vars=data_vars, coords=coords)
+    return concat(datasets, dim=dim, data_vars=data_vars,
+                  coords=coords, fill_value=fill_value)


 _CONCAT_DIM_DEFAULT = utils.ReprObject('<inferred>')
@@ -442,7 +448,8 @@ def _check_shape_tile_ids(combined_tile_ids):


 def _combine_nd(combined_ids, concat_dims, data_vars='all',
-                coords='different', compat='no_conflicts'):
+                coords='different', compat='no_conflicts',
+                fill_value=dtypes.NA):
     """
     Concatenates and merges an N-dimensional structure of datasets.
@@ -472,13 +479,14 @@ def _combine_nd(combined_ids, concat_dims, data_vars='all',
                                                 dim=concat_dim,
                                                 data_vars=data_vars,
                                                 coords=coords,
-                                                compat=compat)
+                                                compat=compat,
+                                                fill_value=fill_value)
     combined_ds = list(combined_ids.values())[0]
     return combined_ds


 def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars,
-                                      coords, compat):
+                                      coords, compat, fill_value=dtypes.NA):
     # Group into lines of datasets which must be combined along dim
     # need to sort by _new_tile_id first for groupby to work
     # TODO remove all these sorted OrderedDicts once python >= 3.6 only
@@ -490,7 +498,8 @@ def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars,
         combined_ids = OrderedDict(sorted(group))
         datasets = combined_ids.values()
         new_combined_ids[new_id] = _auto_combine_1d(datasets, dim, compat,
-                                                    data_vars, coords)
+                                                    data_vars, coords,
+                                                    fill_value)
     return new_combined_ids


@@ -500,18 +509,20 @@ def vars_as_keys(ds):


 def _auto_combine_1d(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
                      compat='no_conflicts',
-                     data_vars='all', coords='different'):
+                     data_vars='all', coords='different',
+                     fill_value=dtypes.NA):
     # This is just the old auto_combine function (which only worked along 1D)
     if concat_dim is not None:
         dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim
         sorted_datasets = sorted(datasets, key=vars_as_keys)
         grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys)
         concatenated = [_auto_concat(list(ds_group), dim=dim,
-                                     data_vars=data_vars, coords=coords)
+                                     data_vars=data_vars, coords=coords,
+                                     fill_value=fill_value)
                         for id, ds_group in grouped_by_vars]
     else:
         concatenated = datasets
-    merged = merge(concatenated, compat=compat)
+    merged = merge(concatenated, compat=compat, fill_value=fill_value)
     return merged


@@ -521,7 +532,7 @@ def _new_tile_id(single_id_ds_pair):


 def _auto_combine(datasets, concat_dims, compat, data_vars, coords,
-                  infer_order_from_coords, ids):
+                  infer_order_from_coords, ids, fill_value=dtypes.NA):
     """
     Calls logic to decide concatenation order before concatenating.
     """
@@ -550,12 +561,14 @@ def _auto_combine(datasets, concat_dims, compat, data_vars, coords,

     # Repeatedly concatenate then merge along each dimension
     combined = _combine_nd(combined_ids, concat_dims, compat=compat,
-                           data_vars=data_vars, coords=coords)
+                           data_vars=data_vars, coords=coords,
+                           fill_value=fill_value)
     return combined


 def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
-                 compat='no_conflicts', data_vars='all', coords='different'):
+                 compat='no_conflicts', data_vars='all', coords='different',
+                 fill_value=dtypes.NA):
     """Attempt to auto-magically combine the given datasets into one.

     This method attempts to combine a list of datasets into a single entity by
     inspecting metadata and using a combination of concat and merge.
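The new keyword in use, mirroring the tests added below (values illustrative):

    import xarray as xr

    a = xr.Dataset({'a': ('x', [2, 3]), 'x': [1, 2]})
    b = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
    xr.concat([a, b], dim='t')                # unaligned gaps become NaN
    xr.concat([a, b], dim='t', fill_value=0)  # gaps become 0 instead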
@@ -596,6 +609,8 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
         Details are in the documentation of concat
     coords : {'minimal', 'different', 'all' or list of str}, optional
         Details are in the documentation of concat
+    fill_value : scalar, optional
+        Value to use for newly missing values

     Returns
     -------
@@ -622,4 +637,4 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
     return _auto_combine(datasets, concat_dims=concat_dims, compat=compat,
                          data_vars=data_vars, coords=coords,
                          infer_order_from_coords=infer_order_from_coords,
-                         ids=False)
+                         ids=False, fill_value=fill_value)
diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py
index 1d8ed169d29..a477df0b0d4 100644
--- a/xarray/tests/test_combine.py
+++ b/xarray/tests/test_combine.py
@@ -7,6 +7,7 @@
 import pytest

 from xarray import DataArray, Dataset, Variable, auto_combine, concat
+from xarray.core import dtypes
 from xarray.core.combine import (
     _auto_combine, _auto_combine_1d, _auto_combine_all_along_first_dim,
     _check_shape_tile_ids, _combine_nd, _infer_concat_order_from_positions,
@@ -237,6 +238,20 @@ def test_concat_multiindex(self):
         assert expected.equals(actual)
         assert isinstance(actual.x.to_index(), pd.MultiIndex)

+    @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
+    def test_concat_fill_value(self, fill_value):
+        datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}),
+                    Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})]
+        if fill_value == dtypes.NA:
+            # if we supply the default, we expect the missing value for a
+            # float array
+            fill_value = np.nan
+        expected = Dataset({'a': (('t', 'x'),
+                                  [[fill_value, 2, 3], [1, 2, fill_value]])},
+                           {'x': [0, 1, 2]})
+        actual = concat(datasets, dim='t', fill_value=fill_value)
+        assert_identical(actual, expected)
+

 class TestConcatDataArray:
     def test_concat(self):
@@ -306,6 +321,19 @@ def test_concat_lazy(self):
         assert combined.shape == (2, 3, 3)
         assert combined.dims == ('z', 'x', 'y')

+    @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
+    def test_concat_fill_value(self, fill_value):
+        foo = DataArray([1, 2], coords=[('x', [1, 2])])
+        bar = DataArray([1, 2], coords=[('x', [1, 3])])
+        if fill_value == dtypes.NA:
+            # if we supply the default, we expect the missing value for a
+            # float array
+            fill_value = np.nan
+        expected = DataArray([[1, 2, fill_value], [1, fill_value, 2]],
+                             dims=['y', 'x'], coords={'x': [1, 2, 3]})
+        actual = concat((foo, bar), dim='y', fill_value=fill_value)
+        assert_identical(actual, expected)
+

 class TestAutoCombine:

@@ -417,6 +445,20 @@ def test_auto_combine_no_concat(self):
                            {'baz': [100]})
         assert_identical(expected, actual)

+    @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
+    def test_auto_combine_fill_value(self, fill_value):
+        datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}),
+                    Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})]
+        if fill_value == dtypes.NA:
+            # if we supply the default, we expect the missing value for a
+            # float array
+            fill_value = np.nan
+        expected = Dataset({'a': (('t', 'x'),
+                                  [[fill_value, 2, 3], [1, 2, fill_value]])},
+                           {'x': [0, 1, 2]})
+        actual = auto_combine(datasets, concat_dim='t', fill_value=fill_value)
+        assert_identical(expected, actual)
+

 def assert_combined_tile_ids_equal(dict1, dict2):
     assert len(dict1) == len(dict2)

From ae1239c58282336b311ee3a6f5d3f4ce5bacdb93 Mon Sep 17 00:00:00 2001
From: Alessandro Amici
Date: Tue, 28 May 2019 23:32:26 +0200
Subject: [PATCH 06/31] cfgrib is now part of conda-forge (#2992)

---
 ci/requirements-py36.yml | 3 +--
 ci/requirements-py37.yml | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/ci/requirements-py36.yml b/ci/requirements-py36.yml
index 03242426a36..d6dafd8d540 100644
--- a/ci/requirements-py36.yml
+++ b/ci/requirements-py36.yml
@@ -24,12 +24,11 @@ dependencies:
   - bottleneck
   - zarr
   - pseudonetcdf>=3.0.1
-  - eccodes
+  - cfgrib>=0.9.2
   - cdms2
   - pynio
   - iris>=1.10
   - pydap
   - lxml
   - pip:
-    - cfgrib>=0.9.2
     - mypy==0.660
diff --git a/ci/requirements-py37.yml b/ci/requirements-py37.yml
index 0cece4ed6dd..c5f5d71b8e5 100644
--- a/ci/requirements-py37.yml
+++ b/ci/requirements-py37.yml
@@ -25,9 +25,8 @@ dependencies:
   - bottleneck
   - zarr
   - pseudonetcdf>=3.0.1
+  - cfgrib>=0.9.2
   - lxml
-  - eccodes
   - pydap
   - pip:
-    - cfgrib>=0.9.2
     - mypy==0.650

From 74e5ff64171e84a2da3984f512d2134a233240e2 Mon Sep 17 00:00:00 2001
From: Alan Brammer
Date: Sat, 1 Jun 2019 01:51:16 -0600
Subject: [PATCH 07/31] Add strftime() to datetime accessor with cftimeindex
 and dask support (#2989)

* Add strftime to DatetimeAccessor

Includes Dask support and cftime values

* add strftime to cftimeindex

change datetime.accessor to pass thru to cftimeindex.strftime

* pep8 formatting cleanup

* Improve cftimeindex and dt.strftime docstrings

* Minor code cleanups

remove extra strftime_through_arrays function

* edits to make docstring prettier for pep8

* Add strftime accessor and cftimeindex method to whats new

* Apply suggested doc edits as single commit

Add cftimeindex.strftime to api-hidden.rst

* Include missing backtick in what's new docs

Co-Authored-By: Spencer Clark

* "an Index" not "a Index"

* Add examples to User Guide docs

Added some short examples to time-series and weather-climate pages

* Update doc/whats-new.rst

Co-Authored-By: Spencer Clark
---
 doc/api-hidden.rst               |  1 +
 doc/time-series.rst              |  9 +++++
 doc/weather-climate.rst          | 12 ++++++
 doc/whats-new.rst                |  7 ++++
 xarray/coding/cftimeindex.py     | 29 ++++++++++++++
 xarray/core/accessors.py         | 68 ++++++++++++++++++++++++++++++++
 xarray/tests/test_accessors.py   | 24 +++++++++++
 xarray/tests/test_cftimeindex.py | 11 ++++++
 8 files changed, 161 insertions(+)

diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst
index 4b2fed8be37..8f82b30a442 100644
--- a/doc/api-hidden.rst
+++ b/doc/api-hidden.rst
@@ -153,3 +153,4 @@

    CFTimeIndex.shift
    CFTimeIndex.to_datetimeindex
+   CFTimeIndex.strftime
diff --git a/doc/time-series.rst b/doc/time-series.rst
index 53efcd45ba2..e198887dd0d 100644
--- a/doc/time-series.rst
+++ b/doc/time-series.rst
@@ -152,6 +152,15 @@ __ http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases

     ds['time'].dt.floor('D')

+The ``.dt`` accessor can also be used to generate formatted datetime strings
+for arrays utilising the same formatting as the standard `datetime.strftime`_.
+
+.. _datetime.strftime: https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior
+
+.. ipython:: python
+
+    ds['time'].dt.strftime('%a, %b %d %H:%M')
+
 .. _resampling:

 Resampling and grouped operations
diff --git a/doc/weather-climate.rst b/doc/weather-climate.rst
index 1950ba62ffb..a17ecd2f2a4 100644
--- a/doc/weather-climate.rst
+++ b/doc/weather-climate.rst
@@ -71,6 +71,18 @@ instance, we can create the same dates and DataArray we created above using:
     dates = xr.cftime_range(start='0001', periods=24, freq='MS', calendar='noleap')
     da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo')

+With :py:meth:`~xarray.CFTimeIndex.strftime` we can also easily generate formatted strings from
+the datetime values of a :py:class:`~xarray.CFTimeIndex` directly or through the
+:py:meth:`~xarray.DataArray.dt` accessor for a :py:class:`~xarray.DataArray`
+using the same formatting as the standard `datetime.strftime`_ convention.
+
+.. _datetime.strftime: https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior
+
+.. ipython:: python
+
+    dates.strftime('%c')
+    da['time'].dt.strftime('%Y%m%d')
+
 For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports:

 - `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index dfdca55d218..55773af92b3 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -42,6 +42,13 @@ Enhancements
   helpful for avoiding file-lock errors when trying to write to files opened
   using ``open_dataset()`` or ``open_dataarray()``. (:issue:`2887`)
   By `Dan Nowacki `_.
+- Added ``strftime`` method to ``.dt`` accessor, making it simpler to hand a
+  datetime ``DataArray`` to other code expecting formatted dates and times.
+  (:issue:`2090`). By `Alan Brammer `_ and
+  `Ryan May `_.
+- Like :py:class:`pandas.DatetimeIndex`, :py:class:`CFTimeIndex` now supports a
+  :py:meth:`~xarray.CFTimeIndex.strftime` method to return an index of string
+  formatted datetimes. By `Alan Brammer `_.

 Bug fixes
 ~~~~~~~~~
diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py
index 1456f8ce3b3..6ce7831a5bc 100644
--- a/xarray/coding/cftimeindex.py
+++ b/xarray/coding/cftimeindex.py
@@ -476,6 +476,35 @@ def to_datetimeindex(self, unsafe=False):
                 'dates.'.format(calendar), RuntimeWarning, stacklevel=2)
         return pd.DatetimeIndex(nptimes)

+    def strftime(self, date_format):
+        """
+        Return an Index of formatted strings specified by date_format, which
+        supports the same string format as the python standard library. Details
+        of the string format can be found in `python string format doc
+        <https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior>`__
+
+        Parameters
+        ----------
+        date_format : str
+            Date format string (e.g. "%Y-%m-%d")
+
+        Returns
+        -------
+        Index
+            Index of formatted strings
+
+        Examples
+        --------
+        >>> rng = xr.cftime_range(start='2000', periods=5, freq='2MS',
+        ...                       calendar='noleap')
+        >>> rng.strftime('%B %d, %Y, %r')
+        Index(['January 01, 2000, 12:00:00 AM', 'March 01, 2000, 12:00:00 AM',
+               'May 01, 2000, 12:00:00 AM', 'July 01, 2000, 12:00:00 AM',
+               'September 01, 2000, 12:00:00 AM'],
+              dtype='object')
+        """
+        return pd.Index([date.strftime(date_format) for date in self._data])
+

 def _parse_iso8601_without_reso(date_type, datetime_str):
     date, _ = _parse_iso8601_with_reso(date_type, datetime_str)
diff --git a/xarray/core/accessors.py b/xarray/core/accessors.py
index 640060fafe5..806e1579c3a 100644
--- a/xarray/core/accessors.py
+++ b/xarray/core/accessors.py
@@ -110,6 +110,38 @@ def _round_field(values, name, freq):
         return _round_series(values, name, freq)


+def _strftime_through_cftimeindex(values, date_format):
+    """Coerce an array of cftime-like values to a CFTimeIndex
+    and access requested datetime component
+    """
+    from ..coding.cftimeindex import CFTimeIndex
+    values_as_cftimeindex = CFTimeIndex(values.ravel())
+
+    field_values = values_as_cftimeindex.strftime(date_format)
+    return field_values.values.reshape(values.shape)
+
+
+def _strftime_through_series(values, date_format):
+    """Coerce an array of datetime-like values to a pandas Series and
+    apply string formatting
+    """
+    values_as_series = pd.Series(values.ravel())
+    strs = values_as_series.dt.strftime(date_format)
+    return strs.values.reshape(values.shape)
+
+
+def _strftime(values, date_format):
+    if is_np_datetime_like(values.dtype):
+        access_method = _strftime_through_series
+    else:
+        access_method = _strftime_through_cftimeindex
+    if isinstance(values, dask_array_type):
+        from dask.array import map_blocks
+        return map_blocks(access_method, values, date_format)
+    else:
+        return access_method(values, date_format)
+
+
 class DatetimeAccessor:
     """Access datetime fields for DataArrays with datetime-like dtypes.

@@ -256,3 +288,39 @@ def round(self, freq):
             Array-like of datetime fields accessed for each element in values
         '''
         return self._tslib_round_accessor("round", freq)
+
+    def strftime(self, date_format):
+        '''
+        Return an array of formatted strings specified by date_format, which
+        supports the same string format as the python standard library. Details
+        of the string format can be found in `python string format doc
+        <https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior>`__
+
+        Parameters
+        ----------
+        date_format : str
+            date format string (e.g. "%Y-%m-%d")
+
+        Returns
+        -------
+        formatted strings : same type as values
+            Array-like of strings formatted for each element in values
+
+        Examples
+        --------
+        >>> rng = xr.Dataset({'time': datetime.datetime(2000, 1, 1)})
+        >>> rng['time'].dt.strftime('%B %d, %Y, %r')
+        <xarray.DataArray 'strftime' ()>
+        array('January 01, 2000, 12:00:00 AM', dtype=object)
+        """
+
+        '''
+        obj_type = type(self._obj)
+
+        result = _strftime(self._obj.data, date_format)
+
+        return obj_type(
+            result,
+            name="strftime",
+            coords=self._obj.coords,
+            dims=self._obj.dims)
diff --git a/xarray/tests/test_accessors.py b/xarray/tests/test_accessors.py
index 6bda5772143..09041a6a69f 100644
--- a/xarray/tests/test_accessors.py
+++ b/xarray/tests/test_accessors.py
@@ -42,6 +42,10 @@ def test_field_access(self):
         assert_equal(days, self.data.time.dt.day)
         assert_equal(hours, self.data.time.dt.hour)

+    def test_strftime(self):
+        assert ('2000-01-01 01:00:00' == self.data.time.dt.strftime(
+            '%Y-%m-%d %H:%M:%S')[1])
+
     def test_not_datetime_type(self):
         nontime_data = self.data.copy()
         int_data = np.arange(len(self.data.time)).astype('int8')
@@ -60,6 +64,7 @@ def test_dask_field_access(self):
         floor = self.times_data.dt.floor('D')
         ceil = self.times_data.dt.ceil('D')
         round = self.times_data.dt.round('D')
+        strftime = self.times_data.dt.strftime('%Y-%m-%d %H:%M:%S')

         dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50))
         dask_times_2d = xr.DataArray(dask_times_arr,
@@ -73,12 +78,14 @@ def test_dask_field_access(self):
         dask_floor = dask_times_2d.dt.floor('D')
         dask_ceil = dask_times_2d.dt.ceil('D')
         dask_round = dask_times_2d.dt.round('D')
+        dask_strftime = dask_times_2d.dt.strftime('%Y-%m-%d %H:%M:%S')

         # Test that the data isn't eagerly evaluated
         assert isinstance(dask_year.data, da.Array)
         assert isinstance(dask_month.data, da.Array)
         assert isinstance(dask_day.data, da.Array)
         assert isinstance(dask_hour.data, da.Array)
+        assert isinstance(dask_strftime.data, da.Array)

         # Double check that outcome chunksize is unchanged
         dask_chunks = dask_times_2d.chunks
@@ -86,6 +93,7 @@ def test_dask_field_access(self):
         assert dask_month.data.chunks == dask_chunks
         assert dask_day.data.chunks == dask_chunks
         assert dask_hour.data.chunks == dask_chunks
+        assert dask_strftime.data.chunks == dask_chunks

         # Check the actual output from the accessors
         assert_equal(years, dask_year.compute())
@@ -95,6 +103,7 @@ def test_dask_field_access(self):
         assert_equal(floor, dask_floor.compute())
         assert_equal(ceil, dask_ceil.compute())
         assert_equal(round, dask_round.compute())
+        assert_equal(strftime, dask_strftime.compute())

     def test_seasons(self):
         dates = pd.date_range(start="2000/01/01", freq="M", periods=12)
@@ -169,6 +178,21 @@ def test_field_access(data, field):
     assert_equal(result, expected)


+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftime_strftime_access(data):
+    """ compare cftime formatting against datetime formatting """
+    date_format = '%Y%m%d%H'
+    result = data.time.dt.strftime(date_format)
+    datetime_array = xr.DataArray(
+        xr.coding.cftimeindex.CFTimeIndex(
+            data.time.values).to_datetimeindex(),
+        name="stftime",
+        coords=data.time.coords,
+        dims=data.time.dims)
+    expected = datetime_array.dt.strftime(date_format)
+    assert_equal(result, expected)
+
+
 @pytest.mark.skipif(not has_dask, reason='dask not installed')
 @pytest.mark.skipif(not has_cftime, reason='cftime not installed')
 @pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour',
diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py
index c5cdf0a3fee..999af4c9f86 100644
--- a/xarray/tests/test_cftimeindex.py
+++ b/xarray/tests/test_cftimeindex.py
@@ -785,6 +785,17 @@ def test_parse_array_of_cftime_strings():
     np.testing.assert_array_equal(result, expected)


+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+@pytest.mark.parametrize('calendar', _ALL_CALENDARS)
+def test_strftime_of_cftime_array(calendar):
+    date_format = '%Y%m%d%H%M'
+    cf_values = xr.cftime_range('2000', periods=5, calendar=calendar)
+    dt_values = pd.date_range('2000', periods=5)
+    expected = dt_values.strftime(date_format)
+    result = cf_values.strftime(date_format)
+    assert result.equals(expected)
+
+
 @pytest.mark.skipif(not has_cftime, reason='cftime not installed')
 @pytest.mark.parametrize('calendar', _ALL_CALENDARS)
 @pytest.mark.parametrize('unsafe', [False, True])

From 5b3a41d5761edb2240df5a4475196e4939b33719 Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
Date: Sat, 1 Jun 2019 15:42:20 -0400
Subject: [PATCH 08/31] Implement @ operator for DataArray (#2987)

* implement @ operator for DataArray

* flake

* whatsnew

* add doc example and rmatmul
---
 doc/computation.rst            |  6 ++++++
 doc/whats-new.rst              |  2 ++
 xarray/core/dataarray.py       |  8 ++++++++
 xarray/tests/test_dataarray.py | 13 +++++++++++++
 4 files changed, 29 insertions(+)

diff --git a/doc/computation.rst b/doc/computation.rst
index 2d41479f67f..a999318406e 100644
--- a/doc/computation.rst
+++ b/doc/computation.rst
@@ -45,6 +45,12 @@ Use :py:func:`~xarray.where` to conditionally switch between values:

     xr.where(arr > 0, 'positive', 'negative')

+Use `@` to perform matrix multiplication:
+
+.. ipython:: python
+
+    arr @ arr
+
 Data arrays also implement many :py:class:`numpy.ndarray` methods:

 .. ipython:: python
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 55773af92b3..6a90e3b3c4c 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -21,6 +21,8 @@ v0.12.2 (unreleased)
 Enhancements
 ~~~~~~~~~~~~

+- Enable `@` operator for DataArray. This is equivalent to :py:meth:`DataArray.dot`
+  By `Maximilian Roos `_.
 - Add ``fill_value`` argument for reindex, align, and merge operations
   to enable custom fill values. (:issue:`2876`)
   By `Zach Griffith `_.
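A minimal sketch of the new operator (arrays invented for illustration):

    import numpy as np
    import xarray as xr

    x = xr.DataArray(np.arange(6).reshape(2, 3), dims=('i', 'j'))
    y = xr.DataArray(np.ones(3), dims='j')
    x @ y         # same as x.dot(y): contracts over the shared dimension 'j'
    (x @ y).dims  # ('i',)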
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 15abdaf4a92..cab5612dfe1 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -2014,6 +2014,14 @@ def __array_wrap__(self, obj, context=None):
         new_var = self.variable.__array_wrap__(obj, context)
         return self._replace(new_var)

+    def __matmul__(self, obj):
+        return self.dot(obj)
+
+    def __rmatmul__(self, other):
+        # currently somewhat duplicative, as only other DataArrays are
+        # compatible with matmul
+        return computation.dot(other, self)
+
     @staticmethod
     def _unary_op(f):
         @functools.wraps(f)
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 43af27d0696..fb7c9676d2a 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -3456,6 +3456,19 @@ def test_dot(self):
         with pytest.raises(TypeError):
             da.dot(dm.values)

+    def test_matmul(self):
+
+        # copied from above (could make a fixture)
+        x = np.linspace(-3, 3, 6)
+        y = np.linspace(-3, 3, 5)
+        z = range(4)
+        da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
+        da = DataArray(da_vals, coords=[x, y, z], dims=['x', 'y', 'z'])
+
+        result = da @ da
+        expected = da.dot(da)
+        assert_identical(result, expected)
+
     def test_binary_op_join_setting(self):
         dim = 'x'
         align_type = "outer"

From f84c514175f4858caf7dbf2afd2d8fe551208fa0 Mon Sep 17 00:00:00 2001
From: Kevin Squire
Date: Tue, 4 Jun 2019 07:47:08 -0700
Subject: [PATCH 09/31] Add examples for `DataArrayRolling.reduce()` (#2968)

* Add Examples to DataArrayRolling.reduce

* Adds two examples to DataArrayRolling.reduce(); the second example
  shows the interaction of reduce() and min_periods

* Update Rolling window operations documentation

* Add a longer description and examples for the `center` and
  `min_period` parameters to `DataArray.rolling()`
---
 doc/computation.rst    | 27 +++++++++++++++++++++------
 xarray/core/rolling.py | 23 +++++++++++++++++++++++
 2 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/doc/computation.rst b/doc/computation.rst
index a999318406e..3100925a7d3 100644
--- a/doc/computation.rst
+++ b/doc/computation.rst
@@ -149,20 +149,35 @@ name of the dimension as a key (e.g. ``y``) and the window size as the value

     arr.rolling(y=3)

-The label position and minimum number of periods in the rolling window are
-controlled by the ``center`` and ``min_periods`` arguments:
+Aggregation and summary methods can be applied directly to the ``Rolling``
+object:

 .. ipython:: python

-    arr.rolling(y=3, min_periods=2, center=True)
+    r = arr.rolling(y=3)
+    r.reduce(np.std)
+    r.mean()

-Aggregation and summary methods can be applied directly to the ``Rolling`` object:
+Aggregation results are assigned the coordinate at the end of each window by
+default, but can be centered by passing ``center=True`` when constructing the
+``Rolling`` object:

 .. ipython:: python

-    r = arr.rolling(y=3)
+    r = arr.rolling(y=3, center=True)
+    r.mean()
+
+As can be seen above, aggregations of windows which overlap the border of the
+array produce ``nan``s. Setting ``min_periods`` in the call to ``rolling``
+changes the minimum number of observations within the window required to have
+a value when aggregating:
+
+.. ipython:: python
+
+    r = arr.rolling(y=3, min_periods=2)
+    r.mean()
+    r = arr.rolling(y=3, center=True, min_periods=2)
     r.mean()
-    r.reduce(np.std)

 Note that rolling window aggregations are faster when bottleneck_ is installed.
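The ``min_periods`` behaviour described above, in a small sketch (array invented):

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(np.arange(4.0), dims='y')
    arr.rolling(y=3).mean()                 # first two values are nan
    arr.rolling(y=3, min_periods=2).mean()  # only the first value is nan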
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index c113cfebe2a..a884963bf06 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -211,6 +211,29 @@ def reduce(self, func, **kwargs):
         -------
         reduced : DataArray
             Array with summarized data.
+
+        Examples
+        --------
+        >>> da = DataArray(np.arange(8).reshape(2, 4), dims=('a', 'b'))
+        >>>
+        >>> rolling = da.rolling(b=3)
+        >>> rolling.construct('window_dim')
+        <xarray.DataArray (a: 2, b: 4, window_dim: 3)>
+        array([[[np.nan, np.nan, 0], [np.nan, 0, 1], [0, 1, 2], [1, 2, 3]],
+               [[np.nan, np.nan, 4], [np.nan, 4, 5], [4, 5, 6], [5, 6, 7]]])
+        Dimensions without coordinates: a, b, window_dim
+        >>>
+        >>> rolling.reduce(np.sum)
+        <xarray.DataArray (a: 2, b: 4)>
+        array([[nan, nan, 3., 6.],
+               [nan, nan, 15., 18.]])
+        Dimensions without coordinates: a, b
+        >>>
+        >>> rolling = da.rolling(b=3, min_periods=1)
+        >>> rolling.reduce(np.nansum)
+        <xarray.DataArray (a: 2, b: 4)>
+        array([[ 0., 1., 3., 6.],
+               [ 4., 9., 15., 18.]])
         """
         rolling_dim = utils.get_temp_dimname(self.obj.dims, '_rolling_dim')
         windows = self.construct(rolling_dim)
From 5343ccc4bef77020a8181c9da0fdee6bcae35b5f Mon Sep 17 00:00:00 2001
From: James McCreight 
Date: Tue, 4 Jun 2019 14:41:50 -0600
Subject: [PATCH 10/31] Contiguous store with unlim dim bug fix (#2941)

* Contiguous store with unlim dim bug fix
* It is new: Bug fix to 1849, need chunking for vars with unlim dims in netcdf4.
* pep8 spoke
* white space addition
* rm warning, combine if
* check invalid encoding
* pep8 strikes
* Update netCDF4_.py
* Line wrapping

---
 doc/whats-new.rst             | 3 +++
 xarray/backends/netCDF4_.py   | 6 ++++++
 xarray/tests/test_backends.py | 5 +++++
 3 files changed, 14 insertions(+)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 6a90e3b3c4c..a77a0334de7 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -55,6 +55,9 @@ Enhancements

 Bug fixes
 ~~~~~~~~~
+- NetCDF4 output: variables with unlimited dimensions must be chunked (not
+  contiguous) on output. (:issue:`1849`)
+  By `James McCreight `_.
 - indexing with an empty list creates an object with zero-length axis (:issue:`2882`)
   By `Mayeul d'Avezac `_.
 - Return correct count for scalar datetime64 arrays (:issue:`2770`)
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index b3bab9617ee..e411fd3a80e 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -210,6 +210,11 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False,
     if chunks_too_big or changed_shape:
         del encoding['chunksizes']

+    var_has_unlim_dim = any(dim in unlimited_dims for dim in variable.dims)
+    if (not raise_on_invalid and var_has_unlim_dim
+            and 'contiguous' in encoding.keys()):
+        del encoding['contiguous']
+
     for k in safe_to_drop:
         if k in encoding:
             del encoding[k]
@@ -445,6 +450,7 @@ def prepare_variable(self, name, variable, check_encoding=False,
         encoding = _extract_nc4_variable_encoding(
             variable, raise_on_invalid=check_encoding,
             unlimited_dims=unlimited_dims)
+
         if name in self.ds.variables:
             nc4_var = self.ds.variables[name]
         else:
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index f31d3bf4f9b..e63cf45b5ed 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -3525,6 +3525,11 @@ def test_extract_nc4_variable_encoding(self):
         encoding = _extract_nc4_variable_encoding(var, raise_on_invalid=True)
         assert {'shuffle': True} == encoding

+        # Variables with unlim dims must be chunked on output.
+ var = xr.Variable(('x',), [1, 2, 3], {}, {'contiguous': True}) + encoding = _extract_nc4_variable_encoding(var, unlimited_dims=('x',)) + assert {} == encoding + def test_extract_h5nc_encoding(self): # not supported with h5netcdf (yet) var = xr.Variable(('x',), [1, 2, 3], {}, From 44011c9857b249b65b5f403d2791e198e3d67a87 Mon Sep 17 00:00:00 2001 From: Karel van de Plassche Date: Thu, 6 Jun 2019 22:35:49 +0200 Subject: [PATCH 11/31] Fixes #2198: Drop chunksizes when only when original_shape is different, not when it isn't found (#2207) * Fixes #2198: Drop chunksizes when original_shape is different Before this fix chunksizes was dropped even when original_shape was not found in encoding * More direct has_original_shape check * Fixed typo * Added test if chunksizes is kept when no original shape * Fix typo in test name Co-Authored-By: Deepak Cherian * Fix keep_chunksizes_if_no_orignal_shape test by using native open_dataset * Added entry in whats-new * Use roundtrip mechanism in chunksizes conservation test --- doc/whats-new.rst | 4 ++++ xarray/backends/netCDF4_.py | 4 +++- xarray/tests/test_backends.py | 12 ++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a77a0334de7..1df3a79f0fb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,10 @@ v0.12.2 (unreleased) Enhancements ~~~~~~~~~~~~ + +- netCDF chunksizes are now only dropped when original_shape is different, + not when it isn't found. (:issue:`2207`) + By `Karel van de Plassche `_. - Enable `@` operator for DataArray. This is equivalent to :py:meth:`DataArray.dot` By `Maximilian Roos `_. - Add ``fill_value`` argument for reindex, align, and merge operations diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index e411fd3a80e..2396523dca7 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -206,7 +206,9 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, chunks_too_big = any( c > d and dim not in unlimited_dims for c, d, dim in zip(chunksizes, variable.shape, variable.dims)) - changed_shape = encoding.get('original_shape') != variable.shape + has_original_shape = 'original_shape' in encoding + changed_shape = (has_original_shape and + encoding.get('original_shape') != variable.shape) if chunks_too_big or changed_shape: del encoding['chunksizes'] diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e63cf45b5ed..ad66ecf1286 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1134,6 +1134,18 @@ def test_encoding_kwarg_compression(self): assert ds.x.encoding == {} + def test_keep_chunksizes_if_no_original_shape(self): + ds = Dataset({'x': [1, 2, 3]}) + chunksizes = (2, ) + ds.variables['x'].encoding = { + 'chunksizes': chunksizes + } + + with self.roundtrip(ds) as actual: + assert_identical(ds, actual) + assert_array_equal(ds['x'].encoding['chunksizes'], + actual['x'].encoding['chunksizes']) + def test_encoding_chunksizes_unlimited(self): # regression test for GH1225 ds = Dataset({'x': [1, 2, 3], 'y': ('x', [2, 3, 4])}) From 7e649e415e8023b05eb56b3ffadd0d930381f6e0 Mon Sep 17 00:00:00 2001 From: nullptr <3621629+0x0L@users.noreply.github.com> Date: Mon, 10 Jun 2019 06:48:44 +0200 Subject: [PATCH 12/31] fix safe_cast_to_index (#3001) --- doc/whats-new.rst | 4 +++- xarray/coding/cftimeindex.py | 4 ++-- xarray/core/utils.py | 2 +- xarray/tests/test_dataarray.py | 6 ++++++ 4 files changed, 12 insertions(+), 4 deletions(-) diff --git 
a/doc/whats-new.rst b/doc/whats-new.rst index 1df3a79f0fb..151911565e5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -70,8 +70,10 @@ Bug fixes By `Deepak Cherian `_. -- Removed usages of `pytest.config`, which is deprecated (:issue:`2988`:) +- Removed usages of `pytest.config`, which is deprecated (:issue:`2988`) By `Maximilian Roos `_. +- Fixed performance issues with cftime installed (:issue:`3000`) + By `0x0L `_. .. _whats-new.0.12.1: diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 6ce7831a5bc..cf10d6238aa 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -184,7 +184,7 @@ def get_date_type(self): def assert_all_valid_date_type(data): import cftime - if data.size: + if len(data) > 0: sample = data[0] date_type = type(sample) if not isinstance(sample, cftime.datetime): @@ -229,12 +229,12 @@ class CFTimeIndex(pd.Index): date_type = property(get_date_type) def __new__(cls, data, name=None): + assert_all_valid_date_type(data) if name is None and hasattr(data, 'name'): name = data.name result = object.__new__(cls) result._data = np.array(data, dtype='O') - assert_all_valid_date_type(result._data) result.name = name return result diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 94787dd35e2..386019a3dbf 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -62,7 +62,7 @@ def wrapper(*args, **kwargs): def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index: from ..coding.cftimeindex import CFTimeIndex - if index.dtype == 'O': + if len(index) > 0 and index.dtype == 'O': try: return CFTimeIndex(index) except (ImportError, TypeError): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index fb7c9676d2a..3580155781a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1760,6 +1760,12 @@ def test_stack_unstack(self): orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], attrs={'foo': 2}) assert_identical(orig, orig.unstack()) + # test GH3000 + a = orig[:0, :1].stack(dim=('x', 'y')).dim.to_index() + b = pd.MultiIndex(levels=[pd.Int64Index([]), pd.Int64Index([0])], + labels=[[], []], names=['x', 'y']) + pd.util.testing.assert_index_equal(a, b) + actual = orig.stack(z=['x', 'y']).unstack('z').drop(['x', 'y']) assert_identical(orig, actual) From fa55060802a2c772eecef96b6c4ce672dca9dc81 Mon Sep 17 00:00:00 2001 From: nullptr <3621629+0x0L@users.noreply.github.com> Date: Mon, 10 Jun 2019 15:11:10 +0200 Subject: [PATCH 13/31] str accessor (#2991) --- doc/api.rst | 10 + doc/whats-new.rst | 2 + xarray/core/{accessors.py => accessor_dt.py} | 6 +- xarray/core/accessor_str.py | 958 ++++++++++++++++++ xarray/core/dataarray.py | 4 +- ...{test_accessors.py => test_accessor_dt.py} | 0 xarray/tests/test_accessor_str.py | 659 ++++++++++++ 7 files changed, 1635 insertions(+), 4 deletions(-) rename xarray/core/{accessors.py => accessor_dt.py} (98%) create mode 100644 xarray/core/accessor_str.py rename xarray/tests/{test_accessors.py => test_accessor_dt.py} (100%) create mode 100644 xarray/tests/test_accessor_str.py diff --git a/doc/api.rst b/doc/api.rst index 0e766f2cf9a..33c8d9d3ceb 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -324,6 +324,7 @@ Computation DataArray.quantile DataArray.differentiate DataArray.integrate + DataArray.str **Aggregation**: :py:attr:`~DataArray.all` @@ -557,6 +558,15 @@ Resample objects also implement the GroupBy interface core.resample.DatasetResample.nearest core.resample.DatasetResample.pad +Accessors +========= + +.. 
autosummary:: + :toctree: generated/ + + core.accessor_dt.DatetimeAccessor + core.accessor_str.StringAccessor + Custom Indexes ============== .. autosummary:: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 151911565e5..f397178ca5d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -55,6 +55,8 @@ Enhancements - Like :py:class:`pandas.DatetimeIndex`, :py:class:`CFTimeIndex` now supports a :py:meth:`~xarray.CFTimeIndex.strftime` method to return an index of string formatted datetimes. By `Alan Brammer `_. +- Add ``.str`` accessor to DataArrays for string related manipulations. + By `0x0L `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/accessors.py b/xarray/core/accessor_dt.py similarity index 98% rename from xarray/core/accessors.py rename to xarray/core/accessor_dt.py index 806e1579c3a..01cddae188f 100644 --- a/xarray/core/accessors.py +++ b/xarray/core/accessor_dt.py @@ -165,13 +165,13 @@ class DatetimeAccessor: """ - def __init__(self, xarray_obj): - if not _contains_datetime_like_objects(xarray_obj): + def __init__(self, obj): + if not _contains_datetime_like_objects(obj): raise TypeError("'dt' accessor only available for " "DataArray with datetime64 timedelta64 dtype or " "for arrays containing cftime datetime " "objects.") - self._obj = xarray_obj + self._obj = obj def _tslib_field_accessor(name, docstring=None, dtype=None): def f(self, dtype=dtype): diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py new file mode 100644 index 00000000000..564593a032e --- /dev/null +++ b/xarray/core/accessor_str.py @@ -0,0 +1,958 @@ +# The StringAccessor class defined below is an adaptation of the +# pandas string methods source code (see pd.core.strings) + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+import codecs
+import re
+import textwrap
+
+import numpy as np
+
+from .computation import apply_ufunc
+
+
+_cpython_optimized_encoders = (
+    "utf-8", "utf8", "latin-1", "latin1", "iso-8859-1", "mbcs", "ascii"
+)
+_cpython_optimized_decoders = _cpython_optimized_encoders + (
+    "utf-16", "utf-32"
+)
+
+
+def _is_str_like(x):
+    return isinstance(x, str) or isinstance(x, bytes)
+
+
+class StringAccessor:
+    """Vectorized string functions for string-like arrays.
+
+    Similar to pandas, fields can be accessed through the `.str` attribute
+    for applicable DataArrays.
+
+    >>> da = xr.DataArray(['some', 'text', 'in', 'an', 'array'])
+    >>> da.str.len()
+    <xarray.DataArray (dim_0: 5)>
+    array([4, 4, 2, 2, 5])
+    Dimensions without coordinates: dim_0
+
+    """
+
+    def __init__(self, obj):
+        self._obj = obj
+
+    def _apply(self, f, dtype=None):
+        # TODO handling of na values ?
+        if dtype is None:
+            dtype = self._obj.dtype
+
+        g = np.vectorize(f, otypes=[dtype])
+        return apply_ufunc(
+            g, self._obj, dask='parallelized', output_dtypes=[dtype])
+
+    def len(self):
+        '''
+        Compute the length of each element in the array.
+
+        Returns
+        -------
+        lengths array : array of int
+        '''
+        return self._apply(len, dtype=int)
+
+    def __getitem__(self, key):
+        if isinstance(key, slice):
+            return self.slice(start=key.start, stop=key.stop, step=key.step)
+        else:
+            return self.get(key)
+
+    def get(self, i):
+        '''
+        Extract element from indexable in each element in the array.
+
+        Parameters
+        ----------
+        i : int
+            Position of element to extract.
+        default : optional
+            Value for out-of-range index. If not specified (None) defaults to
+            an empty string.
+
+        Returns
+        -------
+        items : array of objects
+        '''
+        obj = slice(-1, None) if i == -1 else slice(i, i + 1)
+        return self._apply(lambda x: x[obj])
+
+    def slice(self, start=None, stop=None, step=None):
+        '''
+        Slice substrings from each element in the array.
+
+        Parameters
+        ----------
+        start : int, optional
+            Start position for slice operation.
+        stop : int, optional
+            Stop position for slice operation.
+        step : int, optional
+            Step size for slice operation.
+
+        Returns
+        -------
+        sliced strings : same type as values
+        '''
+        s = slice(start, stop, step)
+        f = lambda x: x[s]
+        return self._apply(f)
+
+    def slice_replace(self, start=None, stop=None, repl=''):
+        '''
+        Replace a positional slice of a string with another value.
+
+        Parameters
+        ----------
+        start : int, optional
+            Left index position to use for the slice. If not specified (None),
+            the slice is unbounded on the left, i.e. slice from the start
+            of the string.
+        stop : int, optional
+            Right index position to use for the slice. If not specified (None),
+            the slice is unbounded on the right, i.e. slice until the
+            end of the string.
+        repl : str, optional
+            String for replacement. If not specified, the sliced region
+            is replaced with an empty string.
+
+        Returns
+        -------
+        replaced : same type as values
+        '''
+        repl = self._obj.dtype.type(repl)
+
+        def f(x):
+            if len(x[start:stop]) == 0:
+                local_stop = start
+            else:
+                local_stop = stop
+            y = self._obj.dtype.type('')
+            if start is not None:
+                y += x[:start]
+            y += repl
+            if stop is not None:
+                y += x[local_stop:]
+            return y
+
+        return self._apply(f)
+
+    def capitalize(self):
+        '''
+        Convert strings in the array to be capitalized.
+
+        Returns
+        -------
+        capitalized : same type as values
+        '''
+        return self._apply(lambda x: x.capitalize())
+
+    def lower(self):
+        '''
+        Convert strings in the array to lowercase.
+
+        Returns
+        -------
+        lowered : same type as values
+        '''
+        return self._apply(lambda x: x.lower())
+
+    def swapcase(self):
+        '''
+        Convert strings in the array to be swapcased.
+
+        Returns
+        -------
+        swapcased : same type as values
+        '''
+        return self._apply(lambda x: x.swapcase())
+
+    def title(self):
+        '''
+        Convert strings in the array to titlecase.
+
+        Returns
+        -------
+        titled : same type as values
+        '''
+        return self._apply(lambda x: x.title())
+
+    def upper(self):
+        '''
+        Convert strings in the array to uppercase.
+
+        Returns
+        -------
+        uppered : same type as values
+        '''
+        return self._apply(lambda x: x.upper())
+
+    def isalnum(self):
+        '''
+        Check whether all characters in each string are alphanumeric.
+
+        Returns
+        -------
+        isalnum : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.isalnum(), dtype=bool)
+
+    def isalpha(self):
+        '''
+        Check whether all characters in each string are alphabetic.
+
+        Returns
+        -------
+        isalpha : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.isalpha(), dtype=bool)
+
+    def isdecimal(self):
+        '''
+        Check whether all characters in each string are decimal.
+
+        Returns
+        -------
+        isdecimal : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.isdecimal(), dtype=bool)
+
+    def isdigit(self):
+        '''
+        Check whether all characters in each string are digits.
+
+        Returns
+        -------
+        isdigit : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.isdigit(), dtype=bool)
+
+    def islower(self):
+        '''
+        Check whether all characters in each string are lowercase.
+
+        Returns
+        -------
+        islower : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.islower(), dtype=bool)
+
+    def isnumeric(self):
+        '''
+        Check whether all characters in each string are numeric.
+
+        Returns
+        -------
+        isnumeric : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.isnumeric(), dtype=bool)
+
+    def isspace(self):
+        '''
+        Check whether all characters in each string are spaces.
+
+        Returns
+        -------
+        isspace : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.isspace(), dtype=bool)
+
+    def istitle(self):
+        '''
+        Check whether all characters in each string are titlecase.
+
+        Returns
+        -------
+        istitle : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.istitle(), dtype=bool)
+
+    def isupper(self):
+        '''
+        Check whether all characters in each string are uppercase.
+
+        Returns
+        -------
+        isupper : array of bool
+            Array of boolean values with the same shape as the original array.
+        '''
+        return self._apply(lambda x: x.isupper(), dtype=bool)
+
+    def count(self, pat, flags=0):
+        '''
+        Count occurrences of pattern in each string of the array.
+
+        This function is used to count the number of times a particular regex
+        pattern is repeated in each of the string elements of the
+        :class:`~xarray.DataArray`.
+
+        Parameters
+        ----------
+        pat : str
+            Valid regular expression.
+        flags : int, default 0, meaning no flags
+            Flags for the `re` module. For a complete list, `see here
+            `_.
+
+        Returns
+        -------
+        counts : array of int
+        '''
+        pat = self._obj.dtype.type(pat)
+        regex = re.compile(pat, flags=flags)
+        f = lambda x: len(regex.findall(x))
+        return self._apply(f, dtype=int)
+
+    def startswith(self, pat):
+        '''
+        Test if the start of each string element matches a pattern.
+
+        Parameters
+        ----------
+        pat : str
+            Character sequence. Regular expressions are not accepted.
+
+        Returns
+        -------
+        startswith : array of bool
+            An array of booleans indicating whether the given pattern matches
+            the start of each string element.
+        '''
+        pat = self._obj.dtype.type(pat)
+        f = lambda x: x.startswith(pat)
+        return self._apply(f, dtype=bool)
+
+    def endswith(self, pat):
+        '''
+        Test if the end of each string element matches a pattern.
+
+        Parameters
+        ----------
+        pat : str
+            Character sequence. Regular expressions are not accepted.
+
+        Returns
+        -------
+        endswith : array of bool
+            An array of booleans indicating whether the given pattern matches
+            the end of each string element.
+        '''
+        pat = self._obj.dtype.type(pat)
+        f = lambda x: x.endswith(pat)
+        return self._apply(f, dtype=bool)
+
+    def pad(self, width, side='left', fillchar=' '):
+        '''
+        Pad strings in the array up to width.
+
+        Parameters
+        ----------
+        width : int
+            Minimum width of resulting string; additional characters will be
+            filled with character defined in `fillchar`.
+        side : {'left', 'right', 'both'}, default 'left'
+            Side from which to fill resulting string.
+        fillchar : str, default ' '
+            Additional character for filling, default is whitespace.
+
+        Returns
+        -------
+        filled : same type as values
+            Array with a minimum number of characters in each element.
+        '''
+        width = int(width)
+        fillchar = self._obj.dtype.type(fillchar)
+        if len(fillchar) != 1:
+            raise TypeError('fillchar must be a character, not str')
+
+        if side == 'left':
+            f = lambda s: s.rjust(width, fillchar)
+        elif side == 'right':
+            f = lambda s: s.ljust(width, fillchar)
+        elif side == 'both':
+            f = lambda s: s.center(width, fillchar)
+        else:  # pragma: no cover
+            raise ValueError('Invalid side')
+
+        return self._apply(f)
+
+    def center(self, width, fillchar=' '):
+        '''
+        Filling left and right side of strings in the array with an
+        additional character.
+
+        Parameters
+        ----------
+        width : int
+            Minimum width of resulting string; additional characters will be
+            filled with ``fillchar``
+        fillchar : str
+            Additional character for filling, default is whitespace
+
+        Returns
+        -------
+        filled : same type as values
+        '''
+        return self.pad(width, side='both', fillchar=fillchar)
+
+    def ljust(self, width, fillchar=' '):
+        '''
+        Filling right side of strings in the array with an additional
+        character.
+
+        Parameters
+        ----------
+        width : int
+            Minimum width of resulting string; additional characters will be
+            filled with ``fillchar``
+        fillchar : str
+            Additional character for filling, default is whitespace
+
+        Returns
+        -------
+        filled : same type as values
+        '''
+        return self.pad(width, side='right', fillchar=fillchar)
+
+    def rjust(self, width, fillchar=' '):
+        '''
+        Filling left side of strings in the array with an additional character.
+
+        Parameters
+        ----------
+        width : int
+            Minimum width of resulting string; additional characters will be
+            filled with ``fillchar``
+        fillchar : str
+            Additional character for filling, default is whitespace
+
+        Returns
+        -------
+        filled : same type as values
+        '''
+        return self.pad(width, side='left', fillchar=fillchar)
+
+    def zfill(self, width):
+        '''
+        Pad strings in the array by prepending '0' characters.
+
+        Strings in the array are padded with '0' characters on the
+        left of the string to reach a total string length `width`. Strings
+        in the array with length greater or equal to `width` are unchanged.
+
+        Parameters
+        ----------
+        width : int
+            Minimum length of resulting string; strings with length less
+            than `width` will be prepended with '0' characters.
+
+        Returns
+        -------
+        filled : same type as values
+        '''
+        return self.pad(width, side='left', fillchar='0')
+
+    def contains(self, pat, case=True, flags=0, regex=True):
+        '''
+        Test if pattern or regex is contained within a string of the array.
+
+        Return boolean array based on whether a given pattern or regex is
+        contained within a string of the array.
+
+        Parameters
+        ----------
+        pat : str
+            Character sequence or regular expression.
+        case : bool, default True
+            If True, case sensitive.
+        flags : int, default 0 (no flags)
+            Flags to pass through to the re module, e.g. re.IGNORECASE.
+        regex : bool, default True
+            If True, assumes the pat is a regular expression.
+            If False, treats the pat as a literal string.
+
+        Returns
+        -------
+        contains : array of bool
+            An array of boolean values indicating whether the
+            given pattern is contained within the string of each element
+            of the array.
+        '''
+        pat = self._obj.dtype.type(pat)
+        if regex:
+            if not case:
+                flags |= re.IGNORECASE
+
+            regex = re.compile(pat, flags=flags)
+
+            if regex.groups > 0:  # pragma: no cover
+                raise ValueError("This pattern has match groups.")
+
+            f = lambda x: bool(regex.search(x))
+        else:
+            if case:
+                f = lambda x: pat in x
+            else:
+                uppered = self._obj.str.upper()
+                return uppered.str.contains(pat.upper(), regex=False)
+
+        return self._apply(f, dtype=bool)
+
+    def match(self, pat, case=True, flags=0):
+        '''
+        Determine if each string matches a regular expression.
+
+        Parameters
+        ----------
+        pat : string
+            Character sequence or regular expression
+        case : boolean, default True
+            If True, case sensitive
+        flags : int, default 0 (no flags)
+            re module flags, e.g. re.IGNORECASE
+
+        Returns
+        -------
+        matched : array of bool
+        '''
+        if not case:
+            flags |= re.IGNORECASE
+
+        pat = self._obj.dtype.type(pat)
+        regex = re.compile(pat, flags=flags)
+        f = lambda x: bool(regex.match(x))
+        return self._apply(f, dtype=bool)
+
+    def strip(self, to_strip=None, side='both'):
+        '''
+        Remove leading and trailing characters.
+
+        Strip whitespaces (including newlines) or a set of specified characters
+        from each string in the array from left and/or right sides.
+
+        Parameters
+        ----------
+        to_strip : str or None, default None
+            Specifying the set of characters to be removed.
+            All combinations of this set of characters will be stripped.
+            If None then whitespaces are removed.
+        side : {'left', 'right', 'both'}, default 'both'
+            Side from which to strip.
+
+        Returns
+        -------
+        stripped : same type as values
+        '''
+        if to_strip is not None:
+            to_strip = self._obj.dtype.type(to_strip)
+
+        if side == 'both':
+            f = lambda x: x.strip(to_strip)
+        elif side == 'left':
+            f = lambda x: x.lstrip(to_strip)
+        elif side == 'right':
+            f = lambda x: x.rstrip(to_strip)
+        else:  # pragma: no cover
+            raise ValueError('Invalid side')
+
+        return self._apply(f)
+
+    def lstrip(self, to_strip=None):
+        '''
+        Remove leading characters.
+
+        Strip whitespaces (including newlines) or a set of specified characters
+        from each string in the array from the left side.
+
+        Parameters
+        ----------
+        to_strip : str or None, default None
+            Specifying the set of characters to be removed.
+            All combinations of this set of characters will be stripped.
+            If None then whitespaces are removed.
+
+        Returns
+        -------
+        stripped : same type as values
+        '''
+        return self.strip(to_strip, side='left')
+
+    def rstrip(self, to_strip=None):
+        '''
+        Remove trailing characters.
+
+        Strip whitespaces (including newlines) or a set of specified characters
+        from each string in the array from the right side.
+
+        Parameters
+        ----------
+        to_strip : str or None, default None
+            Specifying the set of characters to be removed.
+            All combinations of this set of characters will be stripped.
+            If None then whitespaces are removed.
+
+        Returns
+        -------
+        stripped : same type as values
+        '''
+        return self.strip(to_strip, side='right')
+
+    def wrap(self, width, **kwargs):
+        '''
+        Wrap long strings in the array to be formatted in paragraphs with
+        length less than a given width.
+
+        This method has the same keyword parameters and defaults as
+        :class:`textwrap.TextWrapper`.
+
+        Parameters
+        ----------
+        width : int
+            Maximum line-width
+        expand_tabs : bool, optional
+            If true, tab characters will be expanded to spaces (default: True)
+        replace_whitespace : bool, optional
+            If true, each whitespace character (as defined by
+            string.whitespace) remaining after tab expansion will be replaced
+            by a single space (default: True)
+        drop_whitespace : bool, optional
+            If true, whitespace that, after wrapping, happens to end up at the
+            beginning or end of a line is dropped (default: True)
+        break_long_words : bool, optional
+            If true, then words longer than width will be broken in order to
+            ensure that no lines are longer than width. If it is false, long
+            words will not be broken, and some lines may be longer than width.
+            (default: True)
+        break_on_hyphens : bool, optional
+            If true, wrapping will occur preferably on whitespace and right
+            after hyphens in compound words, as it is customary in English. If
+            false, only whitespaces will be considered as potentially good
+            places for line breaks, but you need to set break_long_words to
+            false if you want truly insecable words. (default: True)
+
+        Returns
+        -------
+        wrapped : same type as values
+        '''
+        # pass the documented TextWrapper keyword arguments through
+        tw = textwrap.TextWrapper(width=width, **kwargs)
+        f = lambda x: '\n'.join(tw.wrap(x))
+        return self._apply(f)
+
+    def translate(self, table):
+        '''
+        Map all characters in the string through the given mapping table.
+
+        Parameters
+        ----------
+        table : dict
+            A mapping of Unicode ordinals to Unicode ordinals, strings,
+            or None. Unmapped characters are left untouched. Characters mapped
+            to None are deleted. :meth:`str.maketrans` is a helper function for
+            making translation tables.
+
+        Returns
+        -------
+        translated : same type as values
+        '''
+        f = lambda x: x.translate(table)
+        return self._apply(f)
+
+    def repeat(self, repeats):
+        '''
+        Duplicate each string in the array.
+
+        Parameters
+        ----------
+        repeats : int
+            Number of repetitions.
+
+        Returns
+        -------
+        repeated : same type as values
+            Array of repeated string objects.
+        '''
+        f = lambda x: repeats * x
+        return self._apply(f)
+
+    def find(self, sub, start=0, end=None, side='left'):
+        '''
+        Return lowest or highest indexes in each string in the array
+        where the substring is fully contained between [start:end].
+        Return -1 on failure.
+
+        Parameters
+        ----------
+        sub : str
+            Substring being searched
+        start : int
+            Left edge index
+        end : int
+            Right edge index
+        side : {'left', 'right'}, default 'left'
+            Starting side for search.
+
+        Returns
+        -------
+        found : array of integer values
+        '''
+        sub = self._obj.dtype.type(sub)
+
+        if side == 'left':
+            method = 'find'
+        elif side == 'right':
+            method = 'rfind'
+        else:  # pragma: no cover
+            raise ValueError('Invalid side')
+
+        if end is None:
+            f = lambda x: getattr(x, method)(sub, start)
+        else:
+            f = lambda x: getattr(x, method)(sub, start, end)
+
+        return self._apply(f, dtype=int)
+
+    def rfind(self, sub, start=0, end=None):
+        '''
+        Return highest indexes in each string in the array
+        where the substring is fully contained between [start:end].
+        Return -1 on failure.
+
+        Parameters
+        ----------
+        sub : str
+            Substring being searched
+        start : int
+            Left edge index
+        end : int
+            Right edge index
+
+        Returns
+        -------
+        found : array of integer values
+        '''
+        return self.find(sub, start=start, end=end, side='right')
+
+    def index(self, sub, start=0, end=None, side='left'):
+        '''
+        Return lowest or highest indexes in each string where the substring is
+        fully contained between [start:end]. This is the same as
+        ``str.find`` except instead of returning -1, it raises a ValueError
+        when the substring is not found.
+
+        Parameters
+        ----------
+        sub : str
+            Substring being searched
+        start : int
+            Left edge index
+        end : int
+            Right edge index
+        side : {'left', 'right'}, default 'left'
+            Starting side for search.
+
+        Returns
+        -------
+        found : array of integer values
+        '''
+        sub = self._obj.dtype.type(sub)
+
+        if side == 'left':
+            method = 'index'
+        elif side == 'right':
+            method = 'rindex'
+        else:  # pragma: no cover
+            raise ValueError('Invalid side')
+
+        if end is None:
+            f = lambda x: getattr(x, method)(sub, start)
+        else:
+            f = lambda x: getattr(x, method)(sub, start, end)
+
+        return self._apply(f, dtype=int)
+
+    def rindex(self, sub, start=0, end=None):
+        '''
+        Return highest indexes in each string where the substring is
+        fully contained between [start:end]. This is the same as
+        ``str.rfind`` except instead of returning -1, it raises a ValueError
+        when the substring is not found.
+
+        Parameters
+        ----------
+        sub : str
+            Substring being searched
+        start : int
+            Left edge index
+        end : int
+            Right edge index
+
+        Returns
+        -------
+        found : array of integer values
+        '''
+        return self.index(sub, start=start, end=end, side='right')
+
+    def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True):
+        '''
+        Replace occurrences of pattern/regex in the array with some string.
+
+        Parameters
+        ----------
+        pat : string or compiled regex
+            String can be a character sequence or regular expression.
+
+        repl : string or callable
+            Replacement string or a callable. The callable is passed the regex
+            match object and must return a replacement string to be used.
+            See :func:`re.sub`.
+
+        n : int, default -1 (all)
+            Number of replacements to make from start
+        case : boolean, default None
+            - If True, case sensitive (the default if `pat` is a string)
+            - Set to False for case insensitive
+            - Cannot be set if `pat` is a compiled regex
+        flags : int, default 0 (no flags)
+            - re module flags, e.g. re.IGNORECASE
+            - Cannot be set if `pat` is a compiled regex
+        regex : boolean, default True
+            - If True, assumes the passed-in pattern is a regular expression.
+            - If False, treats the pattern as a literal string
+            - Cannot be set to False if `pat` is a compiled regex or `repl` is
+              a callable.
+
+        Returns
+        -------
+        replaced : same type as values
+            A copy of the object with all matching occurrences of `pat`
+            replaced by `repl`.
+ ''' + if not (_is_str_like(repl) or callable(repl)): # pragma: no cover + raise TypeError("repl must be a string or callable") + + if _is_str_like(pat): + pat = self._obj.dtype.type(pat) + + if _is_str_like(repl): + repl = self._obj.dtype.type(repl) + + is_compiled_re = isinstance(pat, type(re.compile(''))) + if regex: + if is_compiled_re: + if (case is not None) or (flags != 0): + raise ValueError("case and flags cannot be set" + " when pat is a compiled regex") + else: + # not a compiled regex + # set default case + if case is None: + case = True + + # add case flag, if provided + if case is False: + flags |= re.IGNORECASE + if is_compiled_re or len(pat) > 1 or flags or callable(repl): + n = n if n >= 0 else 0 + compiled = re.compile(pat, flags=flags) + f = lambda x: compiled.sub(repl=repl, string=x, count=n) + else: + f = lambda x: x.replace(pat, repl, n) + else: + if is_compiled_re: + raise ValueError("Cannot use a compiled regex as replacement " + "pattern with regex=False") + if callable(repl): + raise ValueError("Cannot use a callable replacement when " + "regex=False") + f = lambda x: x.replace(pat, repl, n) + return self._apply(f) + + def decode(self, encoding, errors='strict'): + ''' + Decode character string in the array using indicated encoding. + + Parameters + ---------- + encoding : str + errors : str, optional + + Returns + ------- + decoded : same type as values + ''' + if encoding in _cpython_optimized_decoders: + f = lambda x: x.decode(encoding, errors) + else: + decoder = codecs.getdecoder(encoding) + f = lambda x: decoder(x, errors)[0] + return self._apply(f, dtype=np.str_) + + def encode(self, encoding, errors='strict'): + ''' + Encode character string in the array using indicated encoding. + + Parameters + ---------- + encoding : str + errors : str, optional + + Returns + ------- + encoded : same type as values + ''' + if encoding in _cpython_optimized_encoders: + f = lambda x: x.encode(encoding, errors) + else: + encoder = codecs.getencoder(encoding) + f = lambda x: encoder(x, errors)[0] + return self._apply(f, dtype=np.bytes_) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index cab5612dfe1..a492635dc67 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -9,7 +9,8 @@ from ..plot.plot import _PlotMethods from . import ( computation, dtypes, groupby, indexing, ops, resample, rolling, utils) -from .accessors import DatetimeAccessor +from .accessor_dt import DatetimeAccessor +from .accessor_str import StringAccessor from .alignment import align, reindex_like_indexers from .common import AbstractArray, DataWithCoords from .coordinates import ( @@ -162,6 +163,7 @@ class DataArray(AbstractArray, DataWithCoords): _resample_cls = resample.DataArrayResample dt = property(DatetimeAccessor) + str = property(StringAccessor) def __init__(self, data, coords=None, dims=None, name=None, attrs=None, encoding=None, indexes=None, fastpath=False): diff --git a/xarray/tests/test_accessors.py b/xarray/tests/test_accessor_dt.py similarity index 100% rename from xarray/tests/test_accessors.py rename to xarray/tests/test_accessor_dt.py diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py new file mode 100644 index 00000000000..26d5e385df3 --- /dev/null +++ b/xarray/tests/test_accessor_str.py @@ -0,0 +1,659 @@ +# Tests for the `str` accessor are derived from the original +# pandas string accessor tests. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. 
and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import re + +import pytest +import numpy as np +import xarray as xr + +from . import ( + assert_array_equal, assert_equal, has_dask, raises_regex, requires_dask) + + +@pytest.fixture(params=[np.str_, np.bytes_]) +def dtype(request): + return request.param + + +@requires_dask +def test_dask(): + import dask.array as da + arr = da.from_array(['a', 'b', 'c']) + xarr = xr.DataArray(arr) + + result = xarr.str.len().compute() + expected = xr.DataArray([1, 1, 1]) + assert_equal(result, expected) + + +def test_count(dtype): + values = xr.DataArray(['foo', 'foofoo', 'foooofooofommmfoo']).astype(dtype) + result = values.str.count('f[o]+') + expected = xr.DataArray([1, 2, 4]) + assert_equal(result, expected) + + +def test_contains(dtype): + values = xr.DataArray(['Foo', 'xYz', 'fOOomMm__fOo', 'MMM_']).astype(dtype) + # case insensitive using regex + result = values.str.contains('FOO|mmm', case=False) + expected = xr.DataArray([True, False, True, True]) + assert_equal(result, expected) + # case insensitive without regex + result = values.str.contains('foo', regex=False, case=False) + expected = xr.DataArray([True, False, True, False]) + assert_equal(result, expected) + + +def test_starts_ends_with(dtype): + values = xr.DataArray( + ['om', 'foo_nom', 'nom', 'bar_foo', 'foo']).astype(dtype) + result = values.str.startswith('foo') + expected = xr.DataArray([False, True, False, False, True]) + assert_equal(result, expected) + result = values.str.endswith('foo') + expected = xr.DataArray([False, False, False, True, True]) + assert_equal(result, expected) + + +def test_case(dtype): + da = xr.DataArray(['SOme word']).astype(dtype) + capitalized = xr.DataArray(['Some word']).astype(dtype) + lowered = xr.DataArray(['some word']).astype(dtype) + swapped = xr.DataArray(['soME WORD']).astype(dtype) + titled = xr.DataArray(['Some Word']).astype(dtype) + uppered = xr.DataArray(['SOME WORD']).astype(dtype) + 
assert_equal(da.str.capitalize(), capitalized) + assert_equal(da.str.lower(), lowered) + assert_equal(da.str.swapcase(), swapped) + assert_equal(da.str.title(), titled) + assert_equal(da.str.upper(), uppered) + + +def test_replace(dtype): + values = xr.DataArray(['fooBAD__barBAD']).astype(dtype) + result = values.str.replace('BAD[_]*', '') + expected = xr.DataArray(['foobar']).astype(dtype) + assert_equal(result, expected) + + result = values.str.replace('BAD[_]*', '', n=1) + expected = xr.DataArray(['foobarBAD']).astype(dtype) + assert_equal(result, expected) + + s = xr.DataArray(['A', 'B', 'C', 'Aaba', 'Baca', '', + 'CABA', 'dog', 'cat']).astype(dtype) + result = s.str.replace('A', 'YYY') + expected = xr.DataArray(['YYY', 'B', 'C', 'YYYaba', 'Baca', '', 'CYYYBYYY', + 'dog', 'cat']).astype(dtype) + assert_equal(result, expected) + + result = s.str.replace('A', 'YYY', case=False) + expected = xr.DataArray(['YYY', 'B', 'C', 'YYYYYYbYYY', 'BYYYcYYY', + '', 'CYYYBYYY', 'dog', 'cYYYt']).astype(dtype) + assert_equal(result, expected) + + result = s.str.replace('^.a|dog', 'XX-XX ', case=False) + expected = xr.DataArray(['A', 'B', 'C', 'XX-XX ba', 'XX-XX ca', '', + 'XX-XX BA', 'XX-XX ', 'XX-XX t']).astype(dtype) + assert_equal(result, expected) + + +def test_replace_callable(): + values = xr.DataArray(['fooBAD__barBAD']) + # test with callable + repl = lambda m: m.group(0).swapcase() + result = values.str.replace('[a-z][A-Z]{2}', repl, n=2) + exp = xr.DataArray(['foObaD__baRbaD']) + assert_equal(result, exp) + # test regex named groups + values = xr.DataArray(['Foo Bar Baz']) + pat = r"(?P\w+) (?P\w+) (?P\w+)" + repl = lambda m: m.group('middle').swapcase() + result = values.str.replace(pat, repl) + exp = xr.DataArray(['bAR']) + assert_equal(result, exp) + + +def test_replace_unicode(): + # flags + unicode + values = xr.DataArray([b"abcd,\xc3\xa0".decode("utf-8")]) + expected = xr.DataArray([b"abcd, \xc3\xa0".decode("utf-8")]) + pat = re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE) + result = values.str.replace(pat, ", ") + assert_equal(result, expected) + + +def test_replace_compiled_regex(dtype): + values = xr.DataArray(['fooBAD__barBAD']).astype(dtype) + # test with compiled regex + pat = re.compile(dtype('BAD[_]*')) + result = values.str.replace(pat, '') + expected = xr.DataArray(['foobar']).astype(dtype) + assert_equal(result, expected) + + result = values.str.replace(pat, '', n=1) + expected = xr.DataArray(['foobarBAD']).astype(dtype) + assert_equal(result, expected) + + # case and flags provided to str.replace will have no effect + # and will produce warnings + values = xr.DataArray(['fooBAD__barBAD__bad']).astype(dtype) + pat = re.compile(dtype('BAD[_]*')) + + with pytest.raises(ValueError, match="case and flags cannot be"): + result = values.str.replace(pat, '', flags=re.IGNORECASE) + + with pytest.raises(ValueError, match="case and flags cannot be"): + result = values.str.replace(pat, '', case=False) + + with pytest.raises(ValueError, match="case and flags cannot be"): + result = values.str.replace(pat, '', case=True) + + # test with callable + values = xr.DataArray(['fooBAD__barBAD']).astype(dtype) + repl = lambda m: m.group(0).swapcase() + pat = re.compile(dtype('[a-z][A-Z]{2}')) + result = values.str.replace(pat, repl, n=2) + expected = xr.DataArray(['foObaD__baRbaD']).astype(dtype) + assert_equal(result, expected) + + +def test_replace_literal(dtype): + # GH16808 literal replace (regex=False vs regex=True) + values = xr.DataArray(['f.o', 'foo']).astype(dtype) + expected = 
xr.DataArray(['bao', 'bao']).astype(dtype) + result = values.str.replace('f.', 'ba') + assert_equal(result, expected) + + expected = xr.DataArray(['bao', 'foo']).astype(dtype) + result = values.str.replace('f.', 'ba', regex=False) + assert_equal(result, expected) + + # Cannot do a literal replace if given a callable repl or compiled + # pattern + callable_repl = lambda m: m.group(0).swapcase() + compiled_pat = re.compile('[a-z][A-Z]{2}') + + msg = "Cannot use a callable replacement when regex=False" + with pytest.raises(ValueError, match=msg): + values.str.replace('abc', callable_repl, regex=False) + + msg = "Cannot use a compiled regex as replacement pattern with regex=False" + with pytest.raises(ValueError, match=msg): + values.str.replace(compiled_pat, '', regex=False) + + +def test_repeat(dtype): + values = xr.DataArray(['a', 'b', 'c', 'd']).astype(dtype) + result = values.str.repeat(3) + expected = xr.DataArray(['aaa', 'bbb', 'ccc', 'ddd']).astype(dtype) + assert_equal(result, expected) + + +def test_match(dtype): + # New match behavior introduced in 0.13 + values = xr.DataArray(['fooBAD__barBAD', 'foo']).astype(dtype) + result = values.str.match('.*(BAD[_]+).*(BAD)') + expected = xr.DataArray([True, False]) + assert_equal(result, expected) + + values = xr.DataArray(['fooBAD__barBAD', 'foo']).astype(dtype) + result = values.str.match('.*BAD[_]+.*BAD') + expected = xr.DataArray([True, False]) + assert_equal(result, expected) + + +def test_empty_str_methods(): + empty = xr.DataArray(np.empty(shape=(0,), dtype='U')) + empty_str = empty + empty_int = xr.DataArray(np.empty(shape=(0,), dtype=int)) + empty_bool = xr.DataArray(np.empty(shape=(0,), dtype=bool)) + empty_bytes = xr.DataArray(np.empty(shape=(0,), dtype='S')) + + assert_equal(empty_str, empty.str.title()) + assert_equal(empty_int, empty.str.count('a')) + assert_equal(empty_bool, empty.str.contains('a')) + assert_equal(empty_bool, empty.str.startswith('a')) + assert_equal(empty_bool, empty.str.endswith('a')) + assert_equal(empty_str, empty.str.lower()) + assert_equal(empty_str, empty.str.upper()) + assert_equal(empty_str, empty.str.replace('a', 'b')) + assert_equal(empty_str, empty.str.repeat(3)) + assert_equal(empty_bool, empty.str.match('^a')) + assert_equal(empty_int, empty.str.len()) + assert_equal(empty_int, empty.str.find('a')) + assert_equal(empty_int, empty.str.rfind('a')) + assert_equal(empty_str, empty.str.pad(42)) + assert_equal(empty_str, empty.str.center(42)) + assert_equal(empty_str, empty.str.slice(stop=1)) + assert_equal(empty_str, empty.str.slice(step=1)) + assert_equal(empty_str, empty.str.strip()) + assert_equal(empty_str, empty.str.lstrip()) + assert_equal(empty_str, empty.str.rstrip()) + assert_equal(empty_str, empty.str.wrap(42)) + assert_equal(empty_str, empty.str.get(0)) + assert_equal(empty_str, empty_bytes.str.decode('ascii')) + assert_equal(empty_bytes, empty.str.encode('ascii')) + assert_equal(empty_str, empty.str.isalnum()) + assert_equal(empty_str, empty.str.isalpha()) + assert_equal(empty_str, empty.str.isdigit()) + assert_equal(empty_str, empty.str.isspace()) + assert_equal(empty_str, empty.str.islower()) + assert_equal(empty_str, empty.str.isupper()) + assert_equal(empty_str, empty.str.istitle()) + assert_equal(empty_str, empty.str.isnumeric()) + assert_equal(empty_str, empty.str.isdecimal()) + assert_equal(empty_str, empty.str.capitalize()) + assert_equal(empty_str, empty.str.swapcase()) + table = str.maketrans('a', 'b') + assert_equal(empty_str, empty.str.translate(table)) + + +def 
test_ismethods(dtype): + values = ['A', 'b', 'Xy', '4', '3A', '', 'TT', '55', '-', ' '] + str_s = xr.DataArray(values).astype(dtype) + alnum_e = [True, True, True, True, True, False, True, True, False, False] + alpha_e = [True, True, True, False, False, False, True, False, False, + False] + digit_e = [False, False, False, True, False, False, False, True, False, + False] + space_e = [False, False, False, False, False, False, False, False, + False, True] + lower_e = [False, True, False, False, False, False, False, False, + False, False] + upper_e = [True, False, False, False, True, False, True, False, False, + False] + title_e = [True, False, True, False, True, False, False, False, False, + False] + + assert_equal(str_s.str.isalnum(), xr.DataArray(alnum_e)) + assert_equal(str_s.str.isalpha(), xr.DataArray(alpha_e)) + assert_equal(str_s.str.isdigit(), xr.DataArray(digit_e)) + assert_equal(str_s.str.isspace(), xr.DataArray(space_e)) + assert_equal(str_s.str.islower(), xr.DataArray(lower_e)) + assert_equal(str_s.str.isupper(), xr.DataArray(upper_e)) + assert_equal(str_s.str.istitle(), xr.DataArray(title_e)) + + +def test_isnumeric(): + # 0x00bc: ¼ VULGAR FRACTION ONE QUARTER + # 0x2605: ★ not number + # 0x1378: ፸ ETHIOPIC NUMBER SEVENTY + # 0xFF13: 3 Em 3 + values = ['A', '3', '¼', '★', '፸', '3', 'four'] + s = xr.DataArray(values) + numeric_e = [False, True, True, False, True, True, False] + decimal_e = [False, True, False, False, False, True, False] + assert_equal(s.str.isnumeric(), xr.DataArray(numeric_e)) + assert_equal(s.str.isdecimal(), xr.DataArray(decimal_e)) + + +def test_len(dtype): + values = ['foo', 'fooo', 'fooooo', 'fooooooo'] + result = xr.DataArray(values).astype(dtype).str.len() + expected = xr.DataArray([len(x) for x in values]) + assert_equal(result, expected) + + +def test_find(dtype): + values = xr.DataArray(['ABCDEFG', 'BCDEFEF', 'DEFGHIJEF', 'EFGHEF', 'XXX']) + values = values.astype(dtype) + result = values.str.find('EF') + assert_equal(result, xr.DataArray([4, 3, 1, 0, -1])) + expected = xr.DataArray([v.find(dtype('EF')) for v in values.values]) + assert_equal(result, expected) + + result = values.str.rfind('EF') + assert_equal(result, xr.DataArray([4, 5, 7, 4, -1])) + expected = xr.DataArray([v.rfind(dtype('EF')) for v in values.values]) + assert_equal(result, expected) + + result = values.str.find('EF', 3) + assert_equal(result, xr.DataArray([4, 3, 7, 4, -1])) + expected = xr.DataArray([v.find(dtype('EF'), 3) for v in values.values]) + assert_equal(result, expected) + + result = values.str.rfind('EF', 3) + assert_equal(result, xr.DataArray([4, 5, 7, 4, -1])) + expected = xr.DataArray([v.rfind(dtype('EF'), 3) for v in values.values]) + assert_equal(result, expected) + + result = values.str.find('EF', 3, 6) + assert_equal(result, xr.DataArray([4, 3, -1, 4, -1])) + expected = xr.DataArray([v.find(dtype('EF'), 3, 6) for v in values.values]) + assert_equal(result, expected) + + result = values.str.rfind('EF', 3, 6) + assert_equal(result, xr.DataArray([4, 3, -1, 4, -1])) + xp = xr.DataArray([v.rfind(dtype('EF'), 3, 6) for v in values.values]) + assert_equal(result, xp) + + +def test_index(dtype): + s = xr.DataArray(['ABCDEFG', 'BCDEFEF', 'DEFGHIJEF', + 'EFGHEF']).astype(dtype) + + result = s.str.index('EF') + assert_equal(result, xr.DataArray([4, 3, 1, 0])) + + result = s.str.rindex('EF') + assert_equal(result, xr.DataArray([4, 5, 7, 4])) + + result = s.str.index('EF', 3) + assert_equal(result, xr.DataArray([4, 3, 7, 4])) + + result = s.str.rindex('EF', 3) + 
assert_equal(result, xr.DataArray([4, 5, 7, 4])) + + result = s.str.index('E', 4, 8) + assert_equal(result, xr.DataArray([4, 5, 7, 4])) + + result = s.str.rindex('E', 0, 5) + assert_equal(result, xr.DataArray([4, 3, 1, 4])) + + with pytest.raises(ValueError): + result = s.str.index('DE') + + +def test_pad(dtype): + values = xr.DataArray(['a', 'b', 'c', 'eeeee']).astype(dtype) + + result = values.str.pad(5, side='left') + expected = xr.DataArray([' a', ' b', ' c', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + result = values.str.pad(5, side='right') + expected = xr.DataArray(['a ', 'b ', 'c ', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + result = values.str.pad(5, side='both') + expected = xr.DataArray([' a ', ' b ', ' c ', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + +def test_pad_fillchar(dtype): + values = xr.DataArray(['a', 'b', 'c', 'eeeee']).astype(dtype) + + result = values.str.pad(5, side='left', fillchar='X') + expected = xr.DataArray(['XXXXa', 'XXXXb', 'XXXXc', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + result = values.str.pad(5, side='right', fillchar='X') + expected = xr.DataArray(['aXXXX', 'bXXXX', 'cXXXX', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + result = values.str.pad(5, side='both', fillchar='X') + expected = xr.DataArray(['XXaXX', 'XXbXX', 'XXcXX', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + msg = "fillchar must be a character, not str" + with pytest.raises(TypeError, match=msg): + result = values.str.pad(5, fillchar='XY') + + +def test_translate(): + values = xr.DataArray(['abcdefg', 'abcc', 'cdddfg', 'cdefggg']) + table = str.maketrans('abc', 'cde') + result = values.str.translate(table) + expected = xr.DataArray(['cdedefg', 'cdee', 'edddfg', 'edefggg']) + assert_equal(result, expected) + + +def test_center_ljust_rjust(dtype): + values = xr.DataArray(['a', 'b', 'c', 'eeeee']).astype(dtype) + + result = values.str.center(5) + expected = xr.DataArray([' a ', ' b ', ' c ', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + result = values.str.ljust(5) + expected = xr.DataArray(['a ', 'b ', 'c ', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + result = values.str.rjust(5) + expected = xr.DataArray([' a', ' b', ' c', 'eeeee']).astype(dtype) + assert_equal(result, expected) + + +def test_center_ljust_rjust_fillchar(dtype): + values = xr.DataArray(['a', 'bb', 'cccc', 'ddddd', 'eeeeee']).astype(dtype) + result = values.str.center(5, fillchar='X') + expected = xr.DataArray(['XXaXX', 'XXbbX', 'Xcccc', 'ddddd', 'eeeeee']) + assert_equal(result, expected.astype(dtype)) + + result = values.str.ljust(5, fillchar='X') + expected = xr.DataArray(['aXXXX', 'bbXXX', 'ccccX', 'ddddd', 'eeeeee']) + assert_equal(result, expected.astype(dtype)) + + result = values.str.rjust(5, fillchar='X') + expected = xr.DataArray(['XXXXa', 'XXXbb', 'Xcccc', 'ddddd', 'eeeeee']) + assert_equal(result, expected.astype(dtype)) + + # If fillchar is not a charatter, normal str raises TypeError + # 'aaa'.ljust(5, 'XY') + # TypeError: must be char, not str + template = "fillchar must be a character, not {dtype}" + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.center(5, fillchar='XY') + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.ljust(5, fillchar='XY') + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.rjust(5, fillchar='XY') + + +def test_zfill(dtype): + values = xr.DataArray(['1', 
'22', 'aaa', '333', '45678']).astype(dtype) + + result = values.str.zfill(5) + expected = xr.DataArray(['00001', '00022', '00aaa', '00333', '45678']) + assert_equal(result, expected.astype(dtype)) + + result = values.str.zfill(3) + expected = xr.DataArray(['001', '022', 'aaa', '333', '45678']) + assert_equal(result, expected.astype(dtype)) + + +def test_slice(dtype): + arr = xr.DataArray(['aafootwo', 'aabartwo', 'aabazqux']).astype(dtype) + + result = arr.str.slice(2, 5) + exp = xr.DataArray(['foo', 'bar', 'baz']).astype(dtype) + assert_equal(result, exp) + + for start, stop, step in [(0, 3, -1), (None, None, -1), + (3, 10, 2), (3, 0, -1)]: + try: + result = arr.str[start:stop:step] + expected = xr.DataArray([s[start:stop:step] for s in arr.values]) + assert_equal(result, expected.astype(dtype)) + except IndexError: + print('failed on %s:%s:%s' % (start, stop, step)) + raise + + +def test_slice_replace(dtype): + da = lambda x: xr.DataArray(x).astype(dtype) + values = da(['short', 'a bit longer', 'evenlongerthanthat', '']) + + expected = da(['shrt', 'a it longer', 'evnlongerthanthat', '']) + result = values.str.slice_replace(2, 3) + assert_equal(result, expected) + + expected = da(['shzrt', 'a zit longer', 'evznlongerthanthat', 'z']) + result = values.str.slice_replace(2, 3, 'z') + assert_equal(result, expected) + + expected = da(['shzort', 'a zbit longer', 'evzenlongerthanthat', 'z']) + result = values.str.slice_replace(2, 2, 'z') + assert_equal(result, expected) + + expected = da(['shzort', 'a zbit longer', 'evzenlongerthanthat', 'z']) + result = values.str.slice_replace(2, 1, 'z') + assert_equal(result, expected) + + expected = da(['shorz', 'a bit longez', 'evenlongerthanthaz', 'z']) + result = values.str.slice_replace(-1, None, 'z') + assert_equal(result, expected) + + expected = da(['zrt', 'zer', 'zat', 'z']) + result = values.str.slice_replace(None, -2, 'z') + assert_equal(result, expected) + + expected = da(['shortz', 'a bit znger', 'evenlozerthanthat', 'z']) + result = values.str.slice_replace(6, 8, 'z') + assert_equal(result, expected) + + expected = da(['zrt', 'a zit longer', 'evenlongzerthanthat', 'z']) + result = values.str.slice_replace(-10, 3, 'z') + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip(dtype): + values = xr.DataArray([' aa ', ' bb \n', 'cc ']).astype(dtype) + + result = values.str.strip() + expected = xr.DataArray(['aa', 'bb', 'cc']).astype(dtype) + assert_equal(result, expected) + + result = values.str.lstrip() + expected = xr.DataArray(['aa ', 'bb \n', 'cc ']).astype(dtype) + assert_equal(result, expected) + + result = values.str.rstrip() + expected = xr.DataArray([' aa', ' bb', 'cc']).astype(dtype) + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip_args(dtype): + values = xr.DataArray(['xxABCxx', 'xx BNSD', 'LDFJH xx']).astype(dtype) + + rs = values.str.strip('x') + xp = xr.DataArray(['ABC', ' BNSD', 'LDFJH ']).astype(dtype) + assert_equal(rs, xp) + + rs = values.str.lstrip('x') + xp = xr.DataArray(['ABCxx', ' BNSD', 'LDFJH xx']).astype(dtype) + assert_equal(rs, xp) + + rs = values.str.rstrip('x') + xp = xr.DataArray(['xxABC', 'xx BNSD', 'LDFJH ']).astype(dtype) + assert_equal(rs, xp) + + +def test_wrap(): + # test values are: two words less than width, two words equal to width, + # two words greater than width, one word less than width, one word + # equal to width, one word greater than width, multiple tokens with + # trailing whitespace equal to width + values = xr.DataArray(['hello world', 'hello world!', 'hello world!!', + 
'abcdefabcde', 'abcdefabcdef', 'abcdefabcdefa',
+                       'ab ab ab ab ', 'ab ab ab ab a', '\t'])
+
+    # expected values
+    xp = xr.DataArray(['hello world', 'hello world!', 'hello\nworld!!',
+                       'abcdefabcde', 'abcdefabcdef', 'abcdefabcdef\na',
+                       'ab ab ab ab', 'ab ab ab ab\na', ''])
+
+    rs = values.str.wrap(12, break_long_words=True)
+    assert_equal(rs, xp)
+
+    # test with pre and post whitespace (non-unicode), NaN, and non-ascii
+    # Unicode
+    values = xr.DataArray(['  pre  ', '\xac\u20ac\U00008000 abadcafe'])
+    xp = xr.DataArray(['  pre', '\xac\u20ac\U00008000 ab\nadcafe'])
+    rs = values.str.wrap(6)
+    assert_equal(rs, xp)
+
+
+def test_get(dtype):
+    values = xr.DataArray(['a_b_c', 'c_d_e', 'f_g_h']).astype(dtype)
+
+    result = values.str[2]
+    expected = xr.DataArray(['b', 'd', 'g']).astype(dtype)
+    assert_equal(result, expected)
+
+    # bounds testing
+    values = xr.DataArray(['1_2_3_4_5', '6_7_8_9_10', '11_12']).astype(dtype)
+
+    # positive index
+    result = values.str[5]
+    expected = xr.DataArray(['_', '_', '']).astype(dtype)
+    assert_equal(result, expected)
+
+    # negative index
+    result = values.str[-6]
+    expected = xr.DataArray(['_', '8', '']).astype(dtype)
+    assert_equal(result, expected)
+
+
+def test_encode_decode():
+    data = xr.DataArray(['a', 'b', 'a\xe4'])
+    encoded = data.str.encode('utf-8')
+    decoded = encoded.str.decode('utf-8')
+    assert_equal(data, decoded)
+
+
+def test_encode_decode_errors():
+    encodeBase = xr.DataArray(['a', 'b', 'a\x9d'])
+
+    msg = (r"'charmap' codec can't encode character '\\x9d' in position 1:" " character maps to <undefined>")
+    with pytest.raises(UnicodeEncodeError, match=msg):
+        encodeBase.str.encode('cp1252')
+
+    f = lambda x: x.encode('cp1252', 'ignore')
+    result = encodeBase.str.encode('cp1252', 'ignore')
+    expected = xr.DataArray([f(x) for x in encodeBase.values.tolist()])
+    assert_equal(result, expected)
+
+    decodeBase = xr.DataArray([b'a', b'b', b'a\x9d'])
+
+    msg = ("'charmap' codec can't decode byte 0x9d in position 1:" " character maps to <undefined>")
+    with pytest.raises(UnicodeDecodeError, match=msg):
+        decodeBase.str.decode('cp1252')
+
+    f = lambda x: x.decode('cp1252', 'ignore')
+    result = decodeBase.str.decode('cp1252', 'ignore')
+    expected = xr.DataArray([f(x) for x in decodeBase.values.tolist()])
+    assert_equal(result, expected)

From adbd59a0498cce298d88d9383837c968bebae538 Mon Sep 17 00:00:00 2001
From: David Hoese
Date: Mon, 10 Jun 2019 08:25:57 -0500
Subject: [PATCH 14/31] Fix 'to_masked_array' computing dask arrays twice (#3006)

---
 xarray/core/dataarray.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index a492635dc67..6491bcc1b4a 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1748,8 +1748,9 @@ def to_masked_array(self, copy=True):
         result : MaskedArray
             Masked where invalid values (nan or inf) occur.
         """
-        isnull = pd.isnull(self.values)
-        return np.ma.MaskedArray(data=self.values, mask=isnull, copy=copy)
+        values = self.values  # only compute lazy arrays once
+        isnull = pd.isnull(values)
+        return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)
 
     def to_netcdf(self, *args, **kwargs):
         """Write DataArray contents to a netCDF file.
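The fix above matters for dask-backed arrays: every access to ``.values`` triggers a compute, so reading it once and reusing the result halves the work. A minimal sketch of the behaviour (illustrative only, not part of the diff; assumes xarray with dask installed):

```python
import numpy as np
import xarray as xr

# A chunked (dask-backed) DataArray: each .values access executes the graph.
da = xr.DataArray(np.array([1.0, np.nan, 3.0]), dims='x').chunk({'x': 2})

# After the patch, to_masked_array() reads .values once and reuses it for
# both the data and the NaN mask, so the lazy array is computed a single time.
masked = da.to_masked_array()
print(masked.mask)  # [False  True False]
```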
From 43834ac8186a851b7ea5aed8657b72d62fa3695f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 10 Jun 2019 22:45:47 -0400 Subject: [PATCH 15/31] dask-dev tests to allowed failures in travis (#3014) --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 155c0271b30..f351310459b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,6 @@ matrix: - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - env: CONDA_ENV=py36-dask-dev - env: CONDA_ENV=py36-pandas-dev - env: CONDA_ENV=py36-rasterio - env: CONDA_ENV=py36-zarr-dev @@ -25,6 +24,7 @@ matrix: - env: CONDA_ENV=py36-hypothesis allow_failures: + - env: CONDA_ENV=py36-dask-dev - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" From f3efb5bdaf87c5301106cd1122b14d131c7d464b Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 11 Jun 2019 11:01:18 -0400 Subject: [PATCH 16/31] Pytest capture uses match, not message (#3011) * Pytest uses match, not message * correct messages * whatsnew --- doc/whats-new.rst | 3 +++ xarray/tests/test_dataarray.py | 4 ++-- xarray/tests/test_dataset.py | 10 +++++----- xarray/tests/test_plot.py | 3 ++- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f397178ca5d..b01c53e76e2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -76,6 +76,9 @@ Bug fixes By `Maximilian Roos `_. - Fixed performance issues with cftime installed (:issue:`3000`) By `0x0L `_. +- Replace incorrect usages of `message` in pytest assertions + with `match` (:issue:`3011`) + By `Maximilian Roos `_. .. _whats-new.0.12.1: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 3580155781a..68421d96b1b 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1177,7 +1177,7 @@ def test_reset_coords(self): dims=['x', 'y'], name='foo') assert_identical(actual, expected) - with pytest.warns(FutureWarning, message='The inplace argument'): + with pytest.warns(FutureWarning, match='The inplace argument'): with raises_regex(ValueError, 'cannot reset coord'): data = data.reset_coords(inplace=True) with raises_regex(ValueError, 'cannot be found'): @@ -1540,7 +1540,7 @@ def test_reorder_levels(self): obj = self.mda.reorder_levels(x=['level_2', 'level_1']) assert_identical(obj, expected) - with pytest.warns(FutureWarning, message='The inplace argument'): + with pytest.warns(FutureWarning, match='The inplace argument'): array = self.mda.copy() array.reorder_levels(x=['level_2', 'level_1'], inplace=True) assert_identical(array, expected) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ecacf43caf4..dd7d2a98333 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2258,7 +2258,7 @@ def test_set_index(self): obj = ds.set_index(x=mindex.names) assert_identical(obj, expected) - with pytest.warns(FutureWarning, message='The inplace argument'): + with pytest.warns(FutureWarning, match='The inplace argument'): ds.set_index(x=mindex.names, inplace=True) assert_identical(ds, expected) @@ -2278,7 +2278,7 @@ def test_reset_index(self): obj = ds.reset_index('x') assert_identical(obj, expected) - with pytest.warns(FutureWarning, message='The inplace argument'): + with pytest.warns(FutureWarning, match='The inplace argument'): ds.reset_index('x', inplace=True) assert_identical(ds, expected) @@ -2291,7 
+2291,7 @@ def test_reorder_levels(self): reindexed = ds.reorder_levels(x=['level_2', 'level_1']) assert_identical(reindexed, expected) - with pytest.warns(FutureWarning, message='The inplace argument'): + with pytest.warns(FutureWarning, match='The inplace argument'): ds.reorder_levels(x=['level_2', 'level_1'], inplace=True) assert_identical(ds, expected) @@ -2375,7 +2375,7 @@ def test_update(self): assert actual_result is actual assert_identical(expected, actual) - with pytest.warns(FutureWarning, message='The inplace argument'): + with pytest.warns(FutureWarning, match='The inplace argument'): actual = data.update(data, inplace=False) expected = data assert actual is not expected @@ -4615,7 +4615,7 @@ def test_dataset_constructor_aligns_to_explicit_coords( def test_error_message_on_set_supplied(): - with pytest.raises(TypeError, message='has invalid type set'): + with pytest.raises(TypeError, match="has invalid type "): xr.Dataset(dict(date=[1, 2, 3], sec={4})) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 84510da65fe..a79becb3bda 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -210,7 +210,8 @@ def test_2d_coords_line_plot(self): hdl = da.plot.line(x='lon', hue='y') assert len(hdl) == 4 - with pytest.raises(ValueError, message='If x or y are 2D '): + with pytest.raises( + ValueError, match="For 2D inputs, hue must be a dimension"): da.plot.line(x='lon', hue='lat') def test_2d_before_squeeze(self): From fda60563bbca7eea6fabd6603750f359e1ad00ef Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 11 Jun 2019 18:58:25 -0400 Subject: [PATCH 17/31] Pandas labels deprecation (#3016) * Pandas deprecation * allow for older versions of pandas * update docs --- xarray/core/dataarray.py | 4 ++-- xarray/core/pdcompat.py | 6 +++--- xarray/tests/test_dataarray.py | 8 ++++++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6491bcc1b4a..094b8615880 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1351,7 +1351,7 @@ def stack(self, dimensions=None, **dimensions_kwargs): >>> stacked = arr.stack(z=('x', 'y')) >>> stacked.indexes['z'] MultiIndex(levels=[['a', 'b'], [0, 1, 2]], - labels=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + codes=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], names=['x', 'y']) See also @@ -1394,7 +1394,7 @@ def unstack(self, dim=None): >>> stacked = arr.stack(z=('x', 'y')) >>> stacked.indexes['z'] MultiIndex(levels=[['a', 'b'], [0, 1, 2]], - labels=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + codes=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], names=['x', 'y']) >>> roundtripped = stacked.unstack() >>> arr.identical(roundtripped) diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py index f76634c65aa..7e2b0bbf6c4 100644 --- a/xarray/core/pdcompat.py +++ b/xarray/core/pdcompat.py @@ -57,15 +57,15 @@ def remove_unused_levels(self): -------- >>> i = pd.MultiIndex.from_product([range(2), list('ab')]) MultiIndex(levels=[[0, 1], ['a', 'b']], - labels=[[0, 0, 1, 1], [0, 1, 0, 1]]) + codes=[[0, 0, 1, 1], [0, 1, 0, 1]]) >>> i[2:] MultiIndex(levels=[[0, 1], ['a', 'b']], - labels=[[1, 1], [0, 1]]) + codes=[[1, 1], [0, 1]]) The 0 from the first level is not represented and can be removed >>> i[2:].remove_unused_levels() MultiIndex(levels=[[1], ['a', 'b']], - labels=[[0, 0], [0, 1]]) + codes=[[0, 0], [0, 1]]) """ import pandas.core.algorithms as algos diff --git a/xarray/tests/test_dataarray.py 
b/xarray/tests/test_dataarray.py index 68421d96b1b..ab1f56abd4c 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1762,8 +1762,12 @@ def test_stack_unstack(self): # test GH3000 a = orig[:0, :1].stack(dim=('x', 'y')).dim.to_index() - b = pd.MultiIndex(levels=[pd.Int64Index([]), pd.Int64Index([0])], - labels=[[], []], names=['x', 'y']) + if pd.__version__ < '0.24.0': + b = pd.MultiIndex(levels=[pd.Int64Index([]), pd.Int64Index([0])], + labels=[[], []], names=['x', 'y']) + else: + b = pd.MultiIndex(levels=[pd.Int64Index([]), pd.Int64Index([0])], + codes=[[], []], names=['x', 'y']) pd.util.testing.assert_index_equal(a, b) actual = orig.stack(z=['x', 'y']).unstack('z').drop(['x', 'y']) From 3429ca2aa2a07cd77797f5a4c036d6a325f2003f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 12 Jun 2019 00:56:36 -0400 Subject: [PATCH 18/31] Use flake8 rather than pycodestyle (#3010) * A few flake fixes * isort * bunch of flake8 errors fixed * flake8 config * pep8speaks config * run flake8 in travis * docs to flake8 * pep8speaks configs inherited from setup.cfg * too much isort, skipped base __init__ * imports * install flake8 in travis * update 3.6 reqs --- .pep8speaks.yml | 18 ++++-------------- .travis.yml | 2 +- asv_bench/benchmarks/__init__.py | 5 ++--- ci/requirements-py36.yml | 2 +- ci/requirements-py37.yml | 2 +- doc/conf.py | 2 +- doc/contributing.rst | 8 +++----- doc/examples/_code/weather_data_setup.py | 2 +- setup.cfg | 13 +++++++++++-- xarray/__init__.py | 1 + xarray/backends/__init__.py | 8 ++++---- xarray/backends/file_manager.py | 2 +- xarray/backends/locks.py | 2 +- xarray/backends/netCDF4_.py | 3 ++- xarray/coding/variables.py | 2 +- xarray/core/accessor_str.py | 1 - xarray/core/alignment.py | 2 +- xarray/core/combine.py | 4 ++-- xarray/core/common.py | 6 +++--- xarray/core/dataset.py | 7 +++---- xarray/core/duck_array_ops.py | 2 +- xarray/core/missing.py | 2 +- xarray/core/options.py | 2 +- xarray/core/resample_cftime.py | 10 ++++++---- xarray/core/rolling.py | 4 ++-- xarray/core/utils.py | 8 ++++---- xarray/plot/__init__.py | 4 +--- xarray/plot/facetgrid.py | 7 +++---- xarray/plot/utils.py | 5 ++--- xarray/testing.py | 3 +-- xarray/tests/__init__.py | 14 +++++++------- xarray/tests/test_accessor_str.py | 10 +++++----- xarray/tests/test_backends.py | 22 +++++++++++----------- xarray/tests/test_backends_lru_cache.py | 1 + xarray/tests/test_cftime_offsets.py | 4 ++-- xarray/tests/test_cftimeindex.py | 2 +- xarray/tests/test_cftimeindex_resample.py | 6 +++--- xarray/tests/test_coding_times.py | 2 +- xarray/tests/test_combine.py | 4 ++-- xarray/tests/test_dataarray.py | 20 ++++---------------- xarray/tests/test_dataset.py | 8 ++++---- xarray/tests/test_distributed.py | 2 ++ xarray/tests/test_duck_array_ops.py | 4 ++-- xarray/tests/test_interp.py | 2 +- xarray/tests/test_merge.py | 2 +- xarray/tests/test_plot.py | 7 +++---- xarray/tutorial.py | 1 - 47 files changed, 116 insertions(+), 134 deletions(-) diff --git a/.pep8speaks.yml b/.pep8speaks.yml index 018003f2223..8d87864e426 100644 --- a/.pep8speaks.yml +++ b/.pep8speaks.yml @@ -1,16 +1,6 @@ -# File : .pep8speaks.yml - -# This should be kept in sync with the duplicate config in the [pycodestyle] -# block of setup.cfg. 
+# https://github.com/OrkoHunter/pep8speaks for more info +# pep8speaks will use the flake8 configs in `setup.cfg` scanner: - diff_only: False # If True, errors caused by only the patch are shown - -pycodestyle: - max-line-length: 79 - ignore: # Errors and warnings to ignore - - E402 # module level import not at top of file - - E731 # do not assign a lambda expression, use a def - - E741 # ambiguous variable name - - W503 # line break before binary operator - - W504 # line break after binary operator + diff_only: False + linter: flake8 diff --git a/.travis.yml b/.travis.yml index f351310459b..913c5e1c0f7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -60,7 +60,7 @@ script: cd doc; sphinx-build -n -j auto -b html -d _build/doctrees . _build/html; elif [[ "$CONDA_ENV" == "lint" ]]; then - pycodestyle xarray ; + flake8 ; elif [[ "$CONDA_ENV" == "py36-hypothesis" ]]; then pytest properties ; else diff --git a/asv_bench/benchmarks/__init__.py b/asv_bench/benchmarks/__init__.py index 997fdfd0db0..d0eb6282fce 100644 --- a/asv_bench/benchmarks/__init__.py +++ b/asv_bench/benchmarks/__init__.py @@ -1,6 +1,5 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function + import itertools import numpy as np diff --git a/ci/requirements-py36.yml b/ci/requirements-py36.yml index d6dafd8d540..aab926ac6aa 100644 --- a/ci/requirements-py36.yml +++ b/ci/requirements-py36.yml @@ -14,7 +14,7 @@ dependencies: - pytest-cov - pytest-env - coveralls - - pycodestyle + - flake8 - numpy>=1.12 - pandas>=0.19 - scipy diff --git a/ci/requirements-py37.yml b/ci/requirements-py37.yml index c5f5d71b8e5..fe5afd589c8 100644 --- a/ci/requirements-py37.yml +++ b/ci/requirements-py37.yml @@ -15,7 +15,7 @@ dependencies: - pytest-cov - pytest-env - coveralls - - pycodestyle + - flake8 - numpy>=1.12 - pandas>=0.19 - scipy diff --git a/doc/conf.py b/doc/conf.py index 322741556b6..237669460b2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -13,11 +13,11 @@ # serve to show the default. from __future__ import absolute_import, division, print_function -from contextlib import suppress import datetime import os import subprocess import sys +from contextlib import suppress import xarray diff --git a/doc/contributing.rst b/doc/contributing.rst index fba09497abe..651c1d47db5 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -351,11 +351,11 @@ the more common ``PEP8`` issues: - passing arguments should have spaces after commas, e.g. ``foo(arg1, arg2, kw1='bar')`` :ref:`Continuous Integration ` will run -the `pycodestyle `_ tool +the `flake8 `_ tool and report any stylistic errors in your code. Therefore, it is helpful before -submitting code to run the check yourself:: +submitting code to run the check yourself: - pycodestyle xarray + flake8 Other recommended but optional tools for checking code quality (not currently enforced in CI): @@ -363,8 +363,6 @@ enforced in CI): - `mypy `_ performs static type checking, which can make it easier to catch bugs. Please run ``mypy xarray`` if you annotate any code with `type hints `_. -- `flake8 `_ includes a few more automated - checks than those enforced by pycodestyle. - `isort `_ will highlight incorrectly sorted imports. ``isort -y`` will automatically fix them. See also `flake8-isort `_. 
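The hunks that follow add ``# noqa`` markers in several files as part of the flake8 migration. As a generic illustration of how these markers behave (a sketch for context, not part of the patch itself): flake8 skips any line carrying ``# noqa``, and appending an error code narrows the suppression to that one check.

```python
import os, sys  # noqa: E401

# A bare "# noqa" silences every flake8 warning on its line, while the
# "# noqa: <code>" form suppresses only the named check. E401 flags
# multiple imports on one line; E731 flags assigning a lambda to a name.
square = lambda x: x * x  # noqa: E731

print(square(4), os.name, sys.platform)
```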
diff --git a/doc/examples/_code/weather_data_setup.py b/doc/examples/_code/weather_data_setup.py index 89470542d5a..d3a3e2d065a 100644 --- a/doc/examples/_code/weather_data_setup.py +++ b/doc/examples/_code/weather_data_setup.py @@ -1,6 +1,6 @@ import numpy as np import pandas as pd -import seaborn as sns # pandas aware plotting library +import seaborn as sns # noqa, pandas aware plotting library import xarray as xr diff --git a/setup.cfg b/setup.cfg index 18922b1647a..51449138780 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,9 +11,18 @@ env = UVCDAT_ANONYMOUS_LOG=no # This should be kept in sync with .pep8speaks.yml -[pycodestyle] +[flake8] max-line-length=79 -ignore=E402,E731,E741,W503,W504 +ignore= + E402 + E731 + E741 + W503 + W504 + # Unused imports; TODO: Allow typing to work without triggering errors + F401 +exclude= + doc [isort] default_section=THIRDPARTY diff --git a/xarray/__init__.py b/xarray/__init__.py index 506cb46de26..9eaa705e108 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,3 +1,4 @@ +""" isort:skip_file """ # flake8: noqa from ._version import get_versions diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 9b9e04d9346..292a6d68523 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -3,16 +3,16 @@ DataStores provide a uniform interface for saving and loading data in different formats. They should not be used directly, but rather through Dataset objects. """ -from .common import AbstractDataStore -from .file_manager import FileManager, CachingFileManager, DummyFileManager from .cfgrib_ import CfGribDataStore +from .common import AbstractDataStore +from .file_manager import CachingFileManager, DummyFileManager, FileManager +from .h5netcdf_ import H5NetCDFStore from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore +from .pseudonetcdf_ import PseudoNetCDFDataStore from .pydap_ import PydapDataStore from .pynio_ import NioDataStore from .scipy_ import ScipyDataStore -from .h5netcdf_ import H5NetCDFStore -from .pseudonetcdf_ import PseudoNetCDFDataStore from .zarr import ZarrStore __all__ = [ diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 5955ef54d6e..0d11632fa67 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -1,7 +1,7 @@ import contextlib import threading -from typing import Any, Dict import warnings +from typing import Any, Dict from ..core import utils from ..core.options import OPTIONS diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index 65150562538..bb63186ce3a 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -1,7 +1,7 @@ import multiprocessing import threading -from typing import Any, MutableMapping import weakref +from typing import Any, MutableMapping try: from dask.utils import SerializableLock diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 2396523dca7..268afcfcea5 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -174,7 +174,7 @@ def _force_native_endianness(var): # if endian exists, remove it from the encoding. 
var.encoding.pop('endian', None) # check to see if encoding has a value for endian its 'native' - if not var.encoding.get('endian', 'native') is 'native': + if not var.encoding.get('endian', 'native') == 'native': raise NotImplementedError("Attempt to write non-native endian type, " "this is not supported by the netCDF4 " "python library.") @@ -237,6 +237,7 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, class GroupWrapper: """Wrap netCDF4.Group objects so closing them closes the root group.""" + def __init__(self, value): self.value = value diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index ae8b97c7352..8f5ffe8a38a 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -1,7 +1,7 @@ """Coders for individual Variable objects.""" -from typing import Any import warnings from functools import partial +from typing import Any import numpy as np import pandas as pd diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 564593a032e..4a1983517eb 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -45,7 +45,6 @@ from .computation import apply_ufunc - _cpython_optimized_encoders = ( "utf-8", "utf8", "latin-1", "latin1", "iso-8859-1", "mbcs", "ascii" ) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 295f69a2afc..031861b0ccf 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -8,7 +8,7 @@ import numpy as np import pandas as pd -from . import utils, dtypes +from . import dtypes, utils from .indexing import get_indexer_nd from .utils import is_dict_like, is_full_slice from .variable import IndexVariable, Variable diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 6d922064f6f..0b18aa47dee 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -4,12 +4,12 @@ import pandas as pd -from . import utils, dtypes +from . 
import dtypes, utils from .alignment import align +from .computation import result_name from .merge import merge from .variable import IndexVariable, Variable, as_variable from .variable import concat as concat_vars -from .computation import result_name def concat(objs, dim=None, data_vars='all', coords='different', diff --git a/xarray/core/common.py b/xarray/core/common.py index 00d0383a727..4e5133fd8c6 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,8 +1,9 @@ from collections import OrderedDict from contextlib import suppress from textwrap import dedent -from typing import (Any, Callable, Hashable, Iterable, Iterator, List, Mapping, - MutableMapping, Optional, Tuple, TypeVar, Union) +from typing import ( + Any, Callable, Hashable, Iterable, Iterator, List, Mapping, MutableMapping, + Optional, Tuple, TypeVar, Union) import numpy as np import pandas as pd @@ -13,7 +14,6 @@ from .pycompat import dask_array_type from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs - # Used as a sentinel value to indicate a all dimensions ALL_DIMS = ReprObject('') diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e9ec1445dd4..ced1dba09e2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8,7 +8,7 @@ from distutils.version import LooseVersion from numbers import Number from typing import ( - Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, Sequence) + Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union) import numpy as np import pandas as pd @@ -35,8 +35,7 @@ from .pycompat import TYPE_CHECKING, dask_array_type from .utils import ( Frozen, SortedKeysDict, _check_inplace, decode_numpy_dict_values, - either_dict_or_kwargs, ensure_us_time_resolution, hashable, is_dict_like, - maybe_wrap_array) + either_dict_or_kwargs, hashable, maybe_wrap_array) from .variable import IndexVariable, Variable, as_variable, broadcast_variables if TYPE_CHECKING: @@ -4145,7 +4144,7 @@ def _integrate_one(self, coord, datetime_unit=None): from .variable import Variable if coord not in self.variables and coord not in self.dims: - raise ValueError('Coordinate {} does not exist.'.format(dim)) + raise ValueError('Coordinate {} does not exist.'.format(coord)) coord_var = self[coord].variable if coord_var.ndim != 1: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index b37e01cb7af..bc66eb71ced 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -4,9 +4,9 @@ accept or return xarray objects. """ import contextlib -from functools import partial import inspect import warnings +from functools import partial import numpy as np import pandas as pd diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 3931512325e..6009983beb2 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -9,7 +9,7 @@ from . 
import utils from .common import _contains_datetime_like_objects from .computation import apply_ufunc -from .duck_array_ops import dask_array_type, datetime_to_numeric +from .duck_array_ops import dask_array_type from .utils import OrderedSet, is_scalar from .variable import Variable, broadcast_variables diff --git a/xarray/core/options.py b/xarray/core/options.py index d441a81d325..532d86a8f38 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -59,7 +59,7 @@ def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): def _get_keep_attrs(default): global_choice = OPTIONS['keep_attrs'] - if global_choice is 'default': + if global_choice == 'default': return default elif global_choice in [True, False]: return global_choice diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index e7f41be8667..cac78aabe98 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -36,14 +36,16 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -from ..coding.cftimeindex import CFTimeIndex -from ..coding.cftime_offsets import (cftime_range, normalize_date, - Day, MonthEnd, QuarterEnd, YearEnd, - CFTIME_TICKS, to_offset) import datetime + import numpy as np import pandas as pd +from ..coding.cftime_offsets import ( + CFTIME_TICKS, Day, MonthEnd, QuarterEnd, YearEnd, cftime_range, + normalize_date, to_offset) +from ..coding.cftimeindex import CFTimeIndex + class CFTimeGrouper: """This is a simple container for the grouping parameters that implements a diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index a884963bf06..4773512cdc4 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -7,8 +7,8 @@ from . import dtypes, duck_array_ops, utils from .dask_array_ops import dask_rolling_wrapper from .ops import ( - bn, has_bottleneck, inject_coarsen_methods, - inject_bottleneck_rolling_methods, inject_datasetrolling_methods) + bn, has_bottleneck, inject_bottleneck_rolling_methods, + inject_coarsen_methods, inject_datasetrolling_methods) from .pycompat import dask_array_type diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 386019a3dbf..9b762ab99c7 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -7,16 +7,16 @@ import re import warnings from collections import OrderedDict -from typing import (AbstractSet, Any, Callable, Container, Dict, Hashable, - Iterable, Iterator, Optional, Sequence, - Tuple, TypeVar, cast) +from typing import ( + AbstractSet, Any, Callable, Container, Dict, Hashable, Iterable, Iterator, + Mapping, MutableMapping, MutableSet, Optional, Sequence, Tuple, TypeVar, + cast) import numpy as np import pandas as pd from .pycompat import dask_array_type -from typing import Mapping, MutableMapping, MutableSet try: # Fix typed collections in Python 3.5.0~3.5.2 from .pycompat import Mapping, MutableMapping, MutableSet # noqa: F811 except ImportError: diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index 51712e78bf8..adda541c21d 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -1,7 +1,5 @@ -from .plot import (plot, line, step, contourf, contour, - hist, imshow, pcolormesh) - from .facetgrid import FacetGrid +from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step __all__ = [ 'plot', diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 4a8d77d7b86..9d2b4848319 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -1,14 +1,13 @@ import 
functools import itertools import warnings -from inspect import getfullargspec import numpy as np from ..core.formatting import format_item from .utils import ( - _infer_xy_labels, _process_cmap_cbar_kwargs, - import_matplotlib_pyplot, label_from_attrs) + _infer_xy_labels, _process_cmap_cbar_kwargs, import_matplotlib_pyplot, + label_from_attrs) # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams @@ -483,7 +482,7 @@ def map(self, func, *args, **kwargs): # TODO: better way to verify that an artist is mappable? # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 if (maybe_mappable and - hasattr(maybe_mappable, 'autoscale_None')): + hasattr(maybe_mappable, 'autoscale_None')): self._mappables.append(maybe_mappable) self._finalize_grid(*args[:2]) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 0a507993cd6..18215479d8c 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -2,15 +2,14 @@ import textwrap import warnings from datetime import datetime +from distutils.version import LooseVersion +from inspect import getfullargspec import numpy as np import pandas as pd -from inspect import getfullargspec - from ..core.options import OPTIONS from ..core.utils import is_scalar -from distutils.version import LooseVersion try: import nc_time_axis diff --git a/xarray/testing.py b/xarray/testing.py index eb8a0e8603d..ed015181dfd 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -4,8 +4,7 @@ import numpy as np import pandas as pd -from xarray.core import duck_array_ops -from xarray.core import formatting +from xarray.core import duck_array_ops, formatting from xarray.core.indexes import default_indexes diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 5e559fce526..d3fe5e167a6 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -1,19 +1,19 @@ +import importlib +import re import warnings from contextlib import contextmanager from distutils import version -import re -import importlib -from unittest import mock +from unittest import mock # noqa import numpy as np -from numpy.testing import assert_array_equal # noqa: F401 -from xarray.core.duck_array_ops import allclose_or_equiv # noqa import pytest +from numpy.testing import assert_array_equal # noqa: F401 +import xarray.testing from xarray.core import utils -from xarray.core.options import set_options +from xarray.core.duck_array_ops import allclose_or_equiv # noqa from xarray.core.indexing import ExplicitlyIndexed -import xarray.testing +from xarray.core.options import set_options from xarray.plot.utils import import_seaborn try: diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index 26d5e385df3..800096b806b 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -39,12 +39,12 @@ import re -import pytest import numpy as np +import pytest + import xarray as xr -from . import ( - assert_array_equal, assert_equal, has_dask, raises_regex, requires_dask) +from . 
import assert_equal, requires_dask @pytest.fixture(params=[np.str_, np.bytes_]) @@ -138,14 +138,14 @@ def test_replace(dtype): def test_replace_callable(): values = xr.DataArray(['fooBAD__barBAD']) # test with callable - repl = lambda m: m.group(0).swapcase() + repl = lambda m: m.group(0).swapcase() # noqa result = values.str.replace('[a-z][A-Z]{2}', repl, n=2) exp = xr.DataArray(['foObaD__baRbaD']) assert_equal(result, exp) # test regex named groups values = xr.DataArray(['Foo Bar Baz']) pat = r"(?P\w+) (?P\w+) (?P\w+)" - repl = lambda m: m.group('middle').swapcase() + repl = lambda m: m.group('middle').swapcase() # noqa result = values.str.replace(pat, repl) exp = xr.DataArray(['bAR']) assert_equal(result, exp) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index ad66ecf1286..e71181f52c9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2,15 +2,15 @@ import itertools import math import os.path -from pathlib import Path import pickle import shutil import sys import tempfile -from typing import Optional import warnings from contextlib import ExitStack from io import BytesIO +from pathlib import Path +from typing import Optional import numpy as np import pandas as pd @@ -18,26 +18,26 @@ import xarray as xr from xarray import ( - DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, - save_mfdataset, load_dataset, load_dataarray) + DataArray, Dataset, backends, load_dataarray, load_dataset, open_dataarray, + open_dataset, open_mfdataset, save_mfdataset) from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore +from xarray.coding.variables import SerializationWarning from xarray.core import indexing from xarray.core.options import set_options from xarray.core.pycompat import dask_array_type from xarray.tests import mock -from xarray.coding.variables import SerializationWarning from . 
import ( assert_allclose, assert_array_equal, assert_equal, assert_identical, has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_cfgrib, - requires_cftime, requires_dask, requires_h5netcdf, requires_netCDF4, - requires_pathlib, requires_pseudonetcdf, requires_pydap, requires_pynio, - requires_rasterio, requires_scipy, requires_scipy_or_netCDF4, - requires_zarr, requires_h5fileobj) -from .test_coding_times import (_STANDARD_CALENDARS, _NON_STANDARD_CALENDARS, - _ALL_CALENDARS) + requires_cftime, requires_dask, requires_h5fileobj, requires_h5netcdf, + requires_netCDF4, requires_pathlib, requires_pseudonetcdf, requires_pydap, + requires_pynio, requires_rasterio, requires_scipy, + requires_scipy_or_netCDF4, requires_zarr) +from .test_coding_times import ( + _ALL_CALENDARS, _NON_STANDARD_CALENDARS, _STANDARD_CALENDARS) from .test_dataset import create_test_data try: diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py index d64d718f2f7..aa97f5fb4cb 100644 --- a/xarray/tests/test_backends_lru_cache.py +++ b/xarray/tests/test_backends_lru_cache.py @@ -1,4 +1,5 @@ from unittest import mock + import pytest from xarray.backends.lru_cache import LRUCache diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index 1cf257c96eb..b3560fe3039 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -6,8 +6,8 @@ from xarray import CFTimeIndex from xarray.coding.cftime_offsets import ( - _MONTH_ABBREVIATIONS, BaseCFTimeOffset, Day, Hour, Minute, Second, - MonthBegin, MonthEnd, YearBegin, YearEnd, QuarterBegin, QuarterEnd, + _MONTH_ABBREVIATIONS, BaseCFTimeOffset, Day, Hour, Minute, MonthBegin, + MonthEnd, QuarterBegin, QuarterEnd, Second, YearBegin, YearEnd, _days_in_month, cftime_range, get_date_type, to_cftime_datetime, to_offset) cftime = pytest.importorskip('cftime') diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 999af4c9f86..56c01fbdc28 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -8,7 +8,7 @@ from xarray.coding.cftimeindex import ( CFTimeIndex, _parse_array_of_cftime_strings, _parse_iso8601_with_reso, _parsed_string_to_bounds, assert_all_valid_date_type, parse_iso8601) -from xarray.tests import assert_array_equal, assert_allclose, assert_identical +from xarray.tests import assert_array_equal, assert_identical from . import ( has_cftime, has_cftime_1_0_2_1, has_cftime_or_netCDF4, raises_regex, diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 7aca4492680..108b303e0c0 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -1,12 +1,12 @@ -import pytest - import datetime + import numpy as np import pandas as pd +import pytest + import xarray as xr from xarray.core.resample_cftime import CFTimeGrouper - pytest.importorskip('cftime') pytest.importorskip('pandas', minversion='0.24') diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index d40abd4acc3..421637dd7e3 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -15,7 +15,7 @@ from . 
import ( assert_array_equal, has_cftime, has_cftime_or_netCDF4, has_dask, - requires_cftime_or_netCDF4, requires_cftime) + requires_cftime, requires_cftime_or_netCDF4) try: from pandas.errors import OutOfBoundsDatetime diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index a477df0b0d4..e9b63dd18fc 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -14,8 +14,8 @@ _infer_tile_ids_from_nested_list, _new_tile_id) from . import ( - InaccessibleArray, assert_array_equal, - assert_equal, assert_identical, raises_regex, requires_dask) + InaccessibleArray, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask) from .test_dataset import create_test_data diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index ab1f56abd4c..fd9076e7f65 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1,9 +1,9 @@ import pickle +import sys import warnings from collections import OrderedDict from copy import deepcopy from textwrap import dedent -import sys import numpy as np import pandas as pd @@ -12,7 +12,7 @@ import xarray as xr from xarray import ( DataArray, Dataset, IndexVariable, Variable, align, broadcast) -from xarray.coding.times import CFDatetimeCoder, _import_cftime +from xarray.coding.times import CFDatetimeCoder from xarray.convert import from_cdms2 from xarray.core import dtypes from xarray.core.common import ALL_DIMS, full_like @@ -1259,18 +1259,6 @@ def test_reindex_like_no_index(self): ValueError, 'different size for unlabeled'): foo.reindex_like(bar) - @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) - def test_reindex_fill_value(self, fill_value): - foo = DataArray([10, 20], dims='y', coords={'y': [0, 1]}) - bar = DataArray([10, 20, 30], dims='y', coords={'y': [0, 1, 2]}) - if fill_value == dtypes.NA: - # if we supply the default, we expect the missing value for a - # float array - fill_value = np.nan - actual = x.reindex_like(bar, fill_value=fill_value) - expected = DataArray([10, 20, fill_value], coords=[('y', [0, 1, 2])]) - assert_identical(expected, actual) - @pytest.mark.filterwarnings('ignore:Indexer has dimensions') def test_reindex_regressions(self): # regression test for #279 @@ -1644,8 +1632,8 @@ def test_math_name(self): assert (a + a.rename(None)).name is None assert (a + a.rename('bar')).name is None assert (a + a).name == 'foo' - assert (+a['x']).name is 'x' - assert (a['x'] + 0).name is 'x' + assert (+a['x']).name == 'x' + assert (a['x'] + 0).name == 'x' assert (a + a['x']).name is None def test_math_with_coords(self): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index dd7d2a98333..98e488552d1 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -22,8 +22,8 @@ from . import ( InaccessibleArray, UnexpectedDataAccess, assert_allclose, assert_array_equal, assert_equal, assert_identical, has_cftime, has_dask, - raises_regex, requires_bottleneck, requires_dask, requires_scipy, - source_ndarray, requires_cftime) + raises_regex, requires_bottleneck, requires_cftime, requires_dask, + requires_scipy, source_ndarray) try: import dask.array as da @@ -4756,9 +4756,9 @@ def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key): func_name = 'move_{0}'.format(name) actual = getattr(rolling_obj, name)() - if key is 'z1': # z1 does not depend on 'Time' axis. Stored as it is. + if key == 'z1': # z1 does not depend on 'Time' axis. Stored as it is. 
expected = ds[key] - elif key is 'z2': + elif key == 'z2': expected = getattr(bn, func_name)(ds[key].values, window=7, axis=0, min_count=min_periods) assert_array_equal(actual[key].values, expected) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 17f655cef8d..98c53ef2b12 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -1,4 +1,5 @@ """ isort:skip_file """ +# flake8: noqa: E402 - ignore linters re order of imports import pickle import pytest @@ -28,6 +29,7 @@ da = pytest.importorskip('dask.array') +loop = loop # loop is an imported fixture, which flake8 has issues ack-ing @pytest.fixture diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 75ab5f52a1b..87a7a2863d3 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -7,13 +7,13 @@ import pytest from numpy import array, nan -from xarray import DataArray, Dataset, concat, cftime_range +from xarray import DataArray, Dataset, cftime_range, concat from xarray.core import dtypes, duck_array_ops from xarray.core.duck_array_ops import ( array_notnull_equiv, concatenate, count, first, gradient, last, mean, rolling_window, stack, where) from xarray.core.pycompat import dask_array_type -from xarray.testing import assert_allclose, assert_equal, assert_identical +from xarray.testing import assert_allclose, assert_equal from . import ( assert_array_equal, has_dask, has_np113, raises_regex, requires_cftime, diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index a11e4b9e79a..252f8bcacd4 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -6,8 +6,8 @@ from xarray.tests import ( assert_allclose, assert_equal, requires_cftime, requires_scipy) -from . import has_dask, has_scipy from ..coding.cftimeindex import _parse_array_of_cftime_strings +from . import has_dask, has_scipy from .test_dataset import create_test_data try: diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 0d76db1d1ee..c45195eaa7f 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -2,7 +2,7 @@ import pytest import xarray as xr -from xarray.core import merge, dtypes +from xarray.core import dtypes, merge from . import raises_regex from .test_dataset import create_test_data diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index a79becb3bda..a0952cac47e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -14,10 +14,9 @@ import_seaborn, label_from_attrs) from . import ( - assert_array_equal, assert_equal, raises_regex, requires_cftime, - requires_matplotlib, requires_matplotlib2, requires_seaborn, - requires_nc_time_axis) -from . import has_nc_time_axis + assert_array_equal, assert_equal, has_nc_time_axis, raises_regex, + requires_cftime, requires_matplotlib, requires_matplotlib2, + requires_nc_time_axis, requires_seaborn) # import mpl and change the backend before other mpl imports try: diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 1a977450ed6..01d4f181d7f 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -7,7 +7,6 @@ ''' import hashlib import os as _os -import warnings from urllib.request import urlretrieve from .backends.api import open_dataset as _open_dataset From 2357851c79b88fb1e31e6487743fbbbffc029363 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 12 Jun 2019 09:32:26 -0600 Subject: [PATCH 19/31] More support for missing_value. 
(#2973) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * More support for missing_value. Fixes #2871 * lint fixes. * Use not equivalent instead of not equals check. * lint fix. * if → elif so we don't call fillna twice * Better fix. --- doc/whats-new.rst | 2 ++ xarray/coding/variables.py | 16 +++++++++++++++- xarray/conventions.py | 3 ++- xarray/tests/test_coding.py | 19 ++++++++++++++++++- 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b01c53e76e2..f8ec1e089b1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -72,6 +72,8 @@ Bug fixes By `Deepak Cherian `_. +- Increased support for `missing_value` (:issue:`2871`) + By `Deepak Cherian `_. - Removed usages of `pytest.config`, which is deprecated (:issue:`2988`) By `Maximilian Roos `_. - Fixed performance issues with cftime installed (:issue:`3000`) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 8f5ffe8a38a..c23e45e44de 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -8,6 +8,7 @@ from ..core import dtypes, duck_array_ops, indexing from ..core.pycompat import dask_array_type +from ..core.utils import equivalent from ..core.variable import Variable @@ -145,11 +146,24 @@ class CFMaskCoder(VariableCoder): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) - if encoding.get('_FillValue') is not None: + fv = encoding.get('_FillValue') + mv = encoding.get('missing_value') + + if fv is not None and mv is not None and not equivalent(fv, mv): + raise ValueError("Variable {!r} has multiple fill values {}. " + "Cannot encode data. " + .format(name, [fv, mv])) + + if fv is not None: fill_value = pop_to(encoding, attrs, '_FillValue', name=name) if not pd.isnull(fill_value): data = duck_array_ops.fillna(data, fill_value) + if mv is not None: + fill_value = pop_to(encoding, attrs, 'missing_value', name=name) + if not pd.isnull(fill_value) and fv is None: + data = duck_array_ops.fillna(data, fill_value) + return Variable(dims, data, attrs, encoding) def decode(self, variable, name=None): diff --git a/xarray/conventions.py b/xarray/conventions.py index 5f41639e890..3f8f76b08a2 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -82,7 +82,8 @@ def maybe_encode_nonstring_dtype(var, name=None): if dtype != var.dtype: if np.issubdtype(dtype, np.integer): if (np.issubdtype(var.dtype, np.floating) and - '_FillValue' not in var.attrs): + '_FillValue' not in var.attrs and + 'missing_value' not in var.attrs): warnings.warn('saving variable %s with floating ' 'point data as an integer dtype without ' 'any _FillValue to use for NaNs' % name, diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 95c8ebc0b42..9f937ac7f5e 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -6,7 +6,7 @@ import xarray as xr from xarray.coding import variables -from . import assert_identical, requires_dask +from . 
import assert_equal, assert_identical, requires_dask with suppress(ImportError): import dask.array as da @@ -20,6 +20,23 @@ def test_CFMaskCoder_decode(): assert_identical(expected, encoded) +def test_CFMaskCoder_missing_value(): + expected = xr.DataArray(np.array([[26915, 27755, -9999, 27705], + [25595, -9999, 28315, -9999]]), + dims=['npts', 'ntimes'], + name='tmpk') + expected.attrs['missing_value'] = -9999 + + decoded = xr.decode_cf(expected.to_dataset()) + encoded, _ = xr.conventions.cf_encoder(decoded, decoded.attrs) + + assert_equal(encoded['tmpk'], expected.variable) + + decoded.tmpk.encoding['_FillValue'] = -9940 + with pytest.raises(ValueError): + encoded, _ = xr.conventions.cf_encoder(decoded, decoded.attrs) + + @requires_dask def test_CFMaskCoder_decode_dask(): original = xr.Variable(('x',), [0, -1, 1], {'_FillValue': -1}).chunk() From 7e4bf8623891c4e564bbaede706e1d69c614b74b Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 12 Jun 2019 17:33:25 +0200 Subject: [PATCH 20/31] Feature/merge errormsg (#2971) * merge add errormsg iterable * double_errormsg * only error in loop * add test * update whats-new * ValueError -> TypeError * pep8 --- doc/whats-new.rst | 2 ++ xarray/core/merge.py | 11 ++++++++--- xarray/tests/test_merge.py | 9 +++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f8ec1e089b1..1dc6e0cee1b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -48,6 +48,8 @@ Enhancements helpful for avoiding file-lock errors when trying to write to files opened using ``open_dataset()`` or ``open_dataarray()``. (:issue:`2887`) By `Dan Nowacki `_. +- Better warning message when supplying invalid objects to ``xr.merge`` + (:issue:`2948`). By `Mathias Hauser `_. - Added ``strftime`` method to ``.dt`` accessor, making it simpler to hand a datetime ``DataArray`` to other code expecting formatted dates and times. (:issue:`2090`). 
By `Alan Brammer `_ and diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 421ac39ebd8..c2c6aee7c22 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -533,9 +533,14 @@ def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA): from .dataarray import DataArray from .dataset import Dataset - dict_like_objects = [ - obj.to_dataset() if isinstance(obj, DataArray) else obj - for obj in objects] + dict_like_objects = list() + for obj in objects: + if not (isinstance(obj, (DataArray, Dataset, dict))): + raise TypeError("objects must be an iterable containing only " + "Dataset(s), DataArray(s), and dictionaries.") + + obj = obj.to_dataset() if isinstance(obj, DataArray) else obj + dict_like_objects.append(obj) variables, coord_names, dims = merge_core(dict_like_objects, compat, join, fill_value=fill_value) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index c45195eaa7f..20e0fae8daf 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -67,6 +67,15 @@ def test_merge_alignment_error(self): with raises_regex(ValueError, 'indexes .* not equal'): xr.merge([ds, other], join='exact') + def test_merge_wrong_input_error(self): + with raises_regex(TypeError, "objects must be an iterable"): + xr.merge([1]) + ds = xr.Dataset(coords={'x': [1, 2]}) + with raises_regex(TypeError, "objects must be an iterable"): + xr.merge({'a': ds}) + with raises_regex(TypeError, "objects must be an iterable"): + xr.merge([ds, 1]) + def test_merge_no_conflicts_single_var(self): ds1 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) ds2 = xr.Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}) From 85fc44156722891b333f1b559ed2bb5e33f87fa8 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 14 Jun 2019 20:34:03 -0400 Subject: [PATCH 21/31] Add pytest markers to avoid warnings (#3023) * add pytest markers to avoid warnings * add a pytest straggling (was on different line so not picked up by regex) * whatsnew * revert autoformat --- doc/whats-new.rst | 3 +++ setup.cfg | 4 ++++ xarray/tests/test_dataset.py | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1dc6e0cee1b..e62c7e87d44 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -83,6 +83,9 @@ Bug fixes - Replace incorrect usages of `message` in pytest assertions with `match` (:issue:`3011`) By `Maximilian Roos `_. +- Add explicit pytest markers, now required by pytest + (:issue:`3032`). + By `Maximilian Roos `_. .. 
_whats-new.0.12.1: diff --git a/setup.cfg b/setup.cfg index 51449138780..28cf17b92d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,10 @@ ignore= F401 exclude= doc +markers = + flaky: flaky tests + network: tests requiring a network connection + slow: slow tests [isort] default_section=THIRDPARTY diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 98e488552d1..812e2893db5 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4625,7 +4625,7 @@ def test_error_message_on_set_supplied(): def test_constructor_raises_with_invalid_coords(unaligned_coords): with pytest.raises(ValueError, - message='not a subset of the DataArray dimensions'): + match='not a subset of the DataArray dimensions'): xr.DataArray([1, 2, 3], dims=['x'], coords=unaligned_coords) From c2a2a6efcaf2d279c78da4ba3a87ea96afe78be0 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 14 Jun 2019 23:35:16 -0400 Subject: [PATCH 22/31] Update issue templates (#3019) --- .github/ISSUE_TEMPLATE/bug_report.md | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000000..cd14db03627 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,30 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +#### MCVE Code Sample + +In order for the maintainers to efficiently understand and prioritize issues, we ask you post a "Minimal, Complete and Verifiable Example" (MCVE): http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports + +```python +# Your code here + +``` + +#### Problem Description + +[this should explain **why** the current behavior is a problem and why the expected output is a better solution.] + +#### Expected Output + +#### Output of ``xr.show_versions()`` + +
+<details>
+# Paste the output here xr.show_versions() here
+
+</details>
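The template above asks reporters to pair a minimal, complete and verifiable example with the ``xr.show_versions()`` environment report. A hypothetical submission (illustrative only, not part of the patch) might look like:

```python
import numpy as np
import xarray as xr

# Minimal, self-contained code that reproduces the behaviour being reported
da = xr.DataArray(np.arange(4.0), dims='x', coords={'x': [0, 1, 2, 3]})
print(da.mean())

# Environment details for the collapsible section of the template
xr.show_versions()
```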
From 4c758e6ad282228ec52c277471db7cfb4f1f050f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 18 Jun 2019 01:12:47 -0400 Subject: [PATCH 23/31] Check types in travis (#3024) * lint mypy typing * dot * revert autoformat * allow failures for typing --- .travis.yml | 12 +++++++++--- setup.cfg | 8 ++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 913c5e1c0f7..ee242ebf818 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,6 +21,7 @@ matrix: - env: CONDA_ENV=py36-zarr-dev - env: CONDA_ENV=docs - env: CONDA_ENV=lint + - env: CONDA_ENV=typing - env: CONDA_ENV=py36-hypothesis allow_failures: @@ -30,6 +31,7 @@ matrix: - EXTRA_FLAGS="--run-flaky --run-network-tests" - env: CONDA_ENV=py36-pandas-dev - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=typing before_install: - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; @@ -40,9 +42,10 @@ before_install: - conda info -a install: - - if [[ "$CONDA_ENV" == "docs" ]]; then + - | + if [[ "$CONDA_ENV" == "docs" ]]; then conda env create -n test_env --file doc/environment.yml; - elif [[ "$CONDA_ENV" == "lint" ]]; then + elif [[ "$CONDA_ENV" == "lint" ]] || [[ "$CONDA_ENV" == "typing" ]] ; then conda env create -n test_env --file ci/requirements-py37.yml; else conda env create -n test_env --file ci/requirements-$CONDA_ENV.yml; @@ -56,11 +59,14 @@ script: - which python - python --version - python -OO -c "import xarray" - - if [[ "$CONDA_ENV" == "docs" ]]; then + - | + if [[ "$CONDA_ENV" == "docs" ]]; then cd doc; sphinx-build -n -j auto -b html -d _build/doctrees . _build/html; elif [[ "$CONDA_ENV" == "lint" ]]; then flake8 ; + elif [[ "$CONDA_ENV" == "typing" ]]; then + mypy . ; elif [[ "$CONDA_ENV" == "py36-hypothesis" ]]; then pytest properties ; else diff --git a/setup.cfg b/setup.cfg index 28cf17b92d0..cdfe2ec3e36 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,8 @@ known_first_party=xarray multi_line_output=4 # Most of the numerical computing stack doesn't have type annotations yet. 
+[mypy-affine.*] +ignore_missing_imports = True [mypy-bottleneck.*] ignore_missing_imports = True [mypy-cdms2.*] @@ -85,6 +87,12 @@ ignore_missing_imports = True [mypy-zarr.*] ignore_missing_imports = True +# setuptools is not typed +[mypy-setup] +ignore_errors = True +# versioneer code +[mypy-versioneer.*] +ignore_errors = True # written by versioneer [mypy-xarray._version] ignore_errors = True From 145f25f69078f245313bb8e07b7c6af7509d0de8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 18 Jun 2019 17:05:08 +0300 Subject: [PATCH 24/31] More consistency checks (#2859) * Enable additional invariant checks in xarray's test suite * Tweak internal consistency checks * Various small fixes * Always use internal invariant checks * Fix coordinates type from DataArray.transpose() --- xarray/core/coordinates.py | 2 +- xarray/core/dataarray.py | 5 +- xarray/core/dataset.py | 30 +++++--- xarray/core/merge.py | 2 +- xarray/testing.py | 138 +++++++++++++++++++++++++++-------- xarray/tests/__init__.py | 13 ++-- xarray/tests/test_dataset.py | 5 ++ 7 files changed, 143 insertions(+), 52 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index ea3eaa0f4f2..6a5795ccdc6 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -193,7 +193,7 @@ def _update_coords(self, coords): self._data._variables = variables self._data._coord_names.update(new_coord_names) - self._data._dims = dict(dims) + self._data._dims = dims self._data._indexes = None def __delitem__(self, key): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 094b8615880..2746c32a8dc 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2,6 +2,7 @@ import sys import warnings from collections import OrderedDict +from typing import Any import numpy as np import pandas as pd @@ -67,7 +68,7 @@ def _infer_coords_and_dims(shape, coords, dims): for dim, coord in zip(dims, coords): var = as_variable(coord, name=dim) var.dims = (dim,) - new_coords[dim] = var + new_coords[dim] = var.to_index_variable() sizes = dict(zip(dims, shape)) for k, v in new_coords.items(): @@ -1442,7 +1443,7 @@ def transpose(self, *dims, transpose_coords=None) -> 'DataArray': variable = self.variable.transpose(*dims) if transpose_coords: - coords = {} + coords = OrderedDict() # type: OrderedDict[Any, Variable] for name, coord in self.coords.items(): coord_dims = tuple(dim for dim in dims if dim in coord.dims) coords[name] = coord.variable.transpose(*coord_dims) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ced1dba09e2..026be5ba4b0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -100,7 +100,7 @@ def calculate_dimensions(variables): Returns dictionary mapping from dimension names to sizes. Raises ValueError if any of the dimension sizes conflict. 
""" - dims = OrderedDict() + dims = {} last_used = {} scalar_vars = set(k for k, v in variables.items() if not v.dims) for k, var in variables.items(): @@ -692,7 +692,7 @@ def _construct_direct(cls, variables, coord_names, dims, attrs=None, @classmethod def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None): - dims = dict(calculate_dimensions(variables)) + dims = calculate_dimensions(variables) return cls._construct_direct(variables, coord_names, dims, attrs) # TODO(shoyer): renable type checking on this signature when pytype has a @@ -753,18 +753,20 @@ def _replace_with_new_dims( # type: ignore coord_names: set = None, attrs: 'Optional[OrderedDict]' = __default, indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default, + encoding: Optional[dict] = __default, inplace: bool = False, ) -> T: """Replace variables with recalculated dimensions.""" - dims = dict(calculate_dimensions(variables)) + dims = calculate_dimensions(variables) return self._replace( - variables, coord_names, dims, attrs, indexes, inplace=inplace) + variables, coord_names, dims, attrs, indexes, encoding, + inplace=inplace) def _replace_vars_and_dims( # type: ignore self: T, variables: 'OrderedDict[Any, Variable]' = None, coord_names: set = None, - dims: 'OrderedDict[Any, int]' = None, + dims: Dict[Any, int] = None, attrs: 'Optional[OrderedDict]' = __default, inplace: bool = False, ) -> T: @@ -1080,6 +1082,7 @@ def __delitem__(self, key): """ del self._variables[key] self._coord_names.discard(key) + self._dims = calculate_dimensions(self._variables) # mutable objects should not be hashable # https://github.com/python/mypy/issues/4266 @@ -2469,7 +2472,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs): else: # If dims includes a label of a non-dimension coordinate, # it will be promoted to a 1D coordinate with a single value. 
- variables[k] = v.set_dims(k) + variables[k] = v.set_dims(k).to_index_variable() new_dims = self._dims.copy() new_dims.update(dim) @@ -3556,12 +3559,15 @@ def from_dict(cls, d): def _unary_op(f, keep_attrs=False): @functools.wraps(f) def func(self, *args, **kwargs): - ds = self.coords.to_dataset() - for k in self.data_vars: - ds._variables[k] = f(self._variables[k], *args, **kwargs) - if keep_attrs: - ds._attrs = self._attrs - return ds + variables = OrderedDict() + for k, v in self._variables.items(): + if k in self._coord_names: + variables[k] = v + else: + variables[k] = f(v, *args, **kwargs) + attrs = self._attrs if keep_attrs else None + return self._replace_with_new_dims( + variables, attrs=attrs, encoding=None) return func diff --git a/xarray/core/merge.py b/xarray/core/merge.py index c2c6aee7c22..94a5d4af79a 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -473,7 +473,7 @@ def merge_core(objs, 'coordinates or not in the merged result: %s' % ambiguous_coords) - return variables, coord_names, dict(dims) + return variables, coord_names, dims def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA): diff --git a/xarray/testing.py b/xarray/testing.py index ed015181dfd..42c91b1eda2 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -1,10 +1,15 @@ """Testing functions exposed to the user API""" from collections import OrderedDict +from typing import Hashable, Union import numpy as np import pandas as pd -from xarray.core import duck_array_ops, formatting +from xarray.core import duck_array_ops +from xarray.core import formatting +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.variable import IndexVariable, Variable from xarray.core.indexes import default_indexes @@ -48,12 +53,11 @@ def assert_equal(a, b): assert_identical, assert_allclose, Dataset.equals, DataArray.equals, numpy.testing.assert_array_equal """ - import xarray as xr __tracebackhide__ = True # noqa: F841 assert type(a) == type(b) # noqa - if isinstance(a, (xr.Variable, xr.DataArray)): + if isinstance(a, (Variable, DataArray)): assert a.equals(b), formatting.diff_array_repr(a, b, 'equals') - elif isinstance(a, xr.Dataset): + elif isinstance(a, Dataset): assert a.equals(b), formatting.diff_dataset_repr(a, b, 'equals') else: raise TypeError('{} not supported by assertion comparison' @@ -77,15 +81,14 @@ def assert_identical(a, b): -------- assert_equal, assert_allclose, Dataset.equals, DataArray.equals """ - import xarray as xr __tracebackhide__ = True # noqa: F841 assert type(a) == type(b) # noqa - if isinstance(a, xr.Variable): + if isinstance(a, Variable): assert a.identical(b), formatting.diff_array_repr(a, b, 'identical') - elif isinstance(a, xr.DataArray): + elif isinstance(a, DataArray): assert a.name == b.name assert a.identical(b), formatting.diff_array_repr(a, b, 'identical') - elif isinstance(a, (xr.Dataset, xr.Variable)): + elif isinstance(a, (Dataset, Variable)): assert a.identical(b), formatting.diff_dataset_repr(a, b, 'identical') else: raise TypeError('{} not supported by assertion comparison' @@ -117,15 +120,14 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): -------- assert_identical, assert_equal, numpy.testing.assert_allclose """ - import xarray as xr __tracebackhide__ = True # noqa: F841 assert type(a) == type(b) # noqa kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes) - if isinstance(a, xr.Variable): + if isinstance(a, Variable): assert a.dims == b.dims allclose = 
_data_allclose_or_equiv(a.values, b.values, **kwargs) assert allclose, '{}\n{}'.format(a.values, b.values) - elif isinstance(a, xr.DataArray): + elif isinstance(a, DataArray): assert_allclose(a.variable, b.variable, **kwargs) assert set(a.coords) == set(b.coords) for v in a.coords.variables: @@ -135,7 +137,7 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): b.coords[v].values, **kwargs) assert allclose, '{}\n{}'.format(a.coords[v].values, b.coords[v].values) - elif isinstance(a, xr.Dataset): + elif isinstance(a, Dataset): assert set(a.data_vars) == set(b.data_vars) assert set(a.coords) == set(b.coords) for k in list(a.variables) + list(a.coords): @@ -147,14 +149,12 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): - import xarray as xr - assert isinstance(indexes, OrderedDict), indexes assert all(isinstance(v, pd.Index) for v in indexes.values()), \ {k: type(v) for k, v in indexes.items()} index_vars = {k for k, v in possible_coord_variables.items() - if isinstance(v, xr.IndexVariable)} + if isinstance(v, IndexVariable)} assert indexes.keys() <= index_vars, (set(indexes), index_vars) # Note: when we support non-default indexes, these checks should be opt-in @@ -166,17 +166,97 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): (indexes, defaults) -def _assert_indexes_invariants(a): - """Separate helper function for checking indexes invariants only.""" - import xarray as xr - - if isinstance(a, xr.DataArray): - if a._indexes is not None: - _assert_indexes_invariants_checks(a._indexes, a._coords, a.dims) - elif isinstance(a, xr.Dataset): - if a._indexes is not None: - _assert_indexes_invariants_checks( - a._indexes, a._variables, a._dims) - elif isinstance(a, xr.Variable): - # no indexes - pass +def _assert_variable_invariants(var: Variable, name: Hashable = None): + if name is None: + name_or_empty = () # type: tuple + else: + name_or_empty = (name,) + assert isinstance(var._dims, tuple), name_or_empty + (var._dims,) + assert len(var._dims) == len(var._data.shape), \ + name_or_empty + (var._dims, var._data.shape) + assert isinstance(var._encoding, (type(None), dict)), \ + name_or_empty + (var._encoding,) + assert isinstance(var._attrs, (type(None), OrderedDict)), \ + name_or_empty + (var._attrs,) + + +def _assert_dataarray_invariants(da: DataArray): + assert isinstance(da._variable, Variable), da._variable + _assert_variable_invariants(da._variable) + + assert isinstance(da._coords, OrderedDict), da._coords + assert all( + isinstance(v, Variable) for v in da._coords.values()), da._coords + assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), \ + (da.dims, {k: v.dims for k, v in da._coords.items()}) + assert all(isinstance(v, IndexVariable) + for (k, v) in da._coords.items() + if v.dims == (k,)), \ + {k: type(v) for k, v in da._coords.items()} + for k, v in da._coords.items(): + _assert_variable_invariants(v, k) + + if da._indexes is not None: + _assert_indexes_invariants_checks(da._indexes, da._coords, da.dims) + + assert da._initialized is True + + +def _assert_dataset_invariants(ds: Dataset): + assert isinstance(ds._variables, OrderedDict), type(ds._variables) + assert all( + isinstance(v, Variable) for v in ds._variables.values()), \ + ds._variables + for k, v in ds._variables.items(): + _assert_variable_invariants(v, k) + + assert isinstance(ds._coord_names, set), ds._coord_names + assert ds._coord_names <= 
ds._variables.keys(), \ + (ds._coord_names, set(ds._variables)) + + assert type(ds._dims) is dict, ds._dims + assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims + var_dims = set() # type: set + for v in ds._variables.values(): + var_dims.update(v.dims) + assert ds._dims.keys() == var_dims, (set(ds._dims), var_dims) + assert all(ds._dims[k] == v.sizes[k] + for v in ds._variables.values() + for k in v.sizes), \ + (ds._dims, {k: v.sizes for k, v in ds._variables.items()}) + assert all(isinstance(v, IndexVariable) + for (k, v) in ds._variables.items() + if v.dims == (k,)), \ + {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)} + assert all(v.dims == (k,) + for (k, v) in ds._variables.items() + if k in ds._dims), \ + {k: v.dims for k, v in ds._variables.items() if k in ds._dims} + + if ds._indexes is not None: + _assert_indexes_invariants_checks(ds._indexes, ds._variables, ds._dims) + + assert isinstance(ds._encoding, (type(None), dict)) + assert isinstance(ds._attrs, (type(None), OrderedDict)) + assert ds._initialized is True + + +def _assert_internal_invariants( + xarray_obj: Union[DataArray, Dataset, Variable], +): + """Validate that an xarray object satisfies its own internal invariants. + + This exists for the benefit of xarray's own test suite, but may be useful + in external projects if they (ill-advisedly) create objects using xarray's + private APIs. + """ + if isinstance(xarray_obj, Variable): + _assert_variable_invariants(xarray_obj) + elif isinstance(xarray_obj, DataArray): + _assert_dataarray_invariants(xarray_obj) + elif isinstance(xarray_obj, Dataset): + _assert_dataset_invariants(xarray_obj) + else: + raise TypeError( + '{} is not a supported type for xarray invariant checks' + .format(type(xarray_obj))) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index d3fe5e167a6..dc8b26e4524 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -168,21 +168,20 @@ def source_ndarray(array): # Internal versions of xarray's test functions that validate additional # invariants -# TODO: add more invariant checks. 
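These wrappers route every assertion in the test suite through the new checks. The underlying helper can also be called directly; a minimal sketch (this is private API, so subject to change):

```python
import numpy as np
import xarray as xr
from xarray.testing import _assert_internal_invariants

da = xr.DataArray(np.zeros((2, 3)), dims=['x', 'y'], coords={'x': [0, 1]})
_assert_internal_invariants(da)            # passes silently
_assert_internal_invariants(da.mean('x'))  # still consistent after a reduction
```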
def assert_equal(a, b): xarray.testing.assert_equal(a, b) - xarray.testing._assert_indexes_invariants(a) - xarray.testing._assert_indexes_invariants(b) + xarray.testing._assert_internal_invariants(a) + xarray.testing._assert_internal_invariants(b) def assert_identical(a, b): xarray.testing.assert_identical(a, b) - xarray.testing._assert_indexes_invariants(a) - xarray.testing._assert_indexes_invariants(b) + xarray.testing._assert_internal_invariants(a) + xarray.testing._assert_internal_invariants(b) def assert_allclose(a, b, **kwargs): xarray.testing.assert_allclose(a, b, **kwargs) - xarray.testing._assert_indexes_invariants(a) - xarray.testing._assert_indexes_invariants(b) + xarray.testing._assert_internal_invariants(a) + xarray.testing._assert_internal_invariants(b) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 812e2893db5..5aae56485ce 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2752,6 +2752,11 @@ def test_delitem(self): assert set(data.variables) == all_items - set(['var1', 'numbers']) assert 'numbers' not in data.coords + expected = Dataset() + actual = Dataset({'y': ('x', [1, 2])}) + del actual['y'] + assert_identical(expected, actual) + def test_squeeze(self): data = Dataset({'foo': (['x', 'y', 'z'], [[[1], [2]]])}) for args in [[], [['x']], [['x', 'z']]]: From 9c0bbf744a5235b4187f87de49175e6776d813cb Mon Sep 17 00:00:00 2001 From: Andrew Ross <5852283+andrew-c-ross@users.noreply.github.com> Date: Thu, 20 Jun 2019 11:47:59 -0400 Subject: [PATCH 25/31] Add "errors" keyword argument to drop() and drop_dims() (#2994) (#3028) * Add "errors" keyword argument (GH2994) Adds an errors keyword to Dataset.drop(), Dataset.drop_dims(), and DataArray.drop() (GH2994). Consistent with pandas, the value can be either "raise" or "ignore" * Fix quotes * Different pandas versions raise different errors * Error messages also vary * Correct doc for DataArray.drop; array, not dataset * Require errors argument to be passed with a keyword --- doc/whats-new.rst | 5 +++++ xarray/core/dataarray.py | 10 ++++++--- xarray/core/dataset.py | 37 +++++++++++++++++++++++++--------- xarray/tests/test_dataarray.py | 17 +++++++++++++++- xarray/tests/test_dataset.py | 35 ++++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 14 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e62c7e87d44..ca50856a25e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -59,6 +59,11 @@ Enhancements formatted datetimes. By `Alan Brammer `_. - Add ``.str`` accessor to DataArrays for string related manipulations. By `0x0L `_. +- Add ``errors`` keyword argument to :py:meth:`Dataset.drop` and :py:meth:`Dataset.drop_dims` + that allows ignoring errors if a passed label or dimension is not in the dataset + (:issue:`2994`). + By `Andrew Ross `_. + Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2746c32a8dc..4c3dcc2781a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1462,7 +1462,7 @@ def transpose(self, *dims, transpose_coords=None) -> 'DataArray': def T(self) -> 'DataArray': return self.transpose() - def drop(self, labels, dim=None): + def drop(self, labels, dim=None, *, errors='raise'): """Drop coordinates or index labels from this DataArray. Parameters @@ -1472,14 +1472,18 @@ def drop(self, labels, dim=None): dim : str, optional Dimension along which to drop index labels. By default (if ``dim is None``), drops coordinates rather than index labels. 
-
+        errors : {'raise', 'ignore'}, optional
+            If 'raise' (default), raises a ValueError if
+            any of the coordinates or index labels passed are not
+            in the array. If 'ignore', any given labels that are in the
+            array are dropped and no error is raised.
         Returns
         -------
         dropped : DataArray
         """
         if utils.is_scalar(labels):
             labels = [labels]
-        ds = self._to_temp_dataset().drop(labels, dim)
+        ds = self._to_temp_dataset().drop(labels, dim, errors=errors)
         return self._from_temp_dataset(ds)
 
     def dropna(self, dim, how='any', thresh=None):
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 026be5ba4b0..13a6a6ee9b2 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -2826,7 +2826,7 @@ def _assert_all_in_dataset(self, names, virtual_okay=False):
             raise ValueError('One or more of the specified variables '
                              'cannot be found in this dataset')
 
-    def drop(self, labels, dim=None):
+    def drop(self, labels, dim=None, *, errors='raise'):
         """Drop variables or index labels from this dataset.
 
         Parameters
@@ -2836,33 +2836,41 @@ def drop(self, labels, dim=None):
         dim : None or str, optional
             Dimension along which to drop index labels. By default (if
             ``dim is None``), drops variables rather than index labels.
+        errors : {'raise', 'ignore'}, optional
+            If 'raise' (default), raises a ValueError if
+            any of the variables or index labels passed are not
+            in the dataset. If 'ignore', any given labels that are in the
+            dataset are dropped and no error is raised.
 
         Returns
         -------
         dropped : Dataset
         """
+        if errors not in ['raise', 'ignore']:
+            raise ValueError('errors must be either "raise" or "ignore"')
         if utils.is_scalar(labels):
             labels = [labels]
         if dim is None:
-            return self._drop_vars(labels)
+            return self._drop_vars(labels, errors=errors)
         else:
             try:
                 index = self.indexes[dim]
             except KeyError:
                 raise ValueError(
                     'dimension %r does not have coordinate labels' % dim)
-            new_index = index.drop(labels)
+            new_index = index.drop(labels, errors=errors)
             return self.loc[{dim: new_index}]
 
-    def _drop_vars(self, names):
-        self._assert_all_in_dataset(names)
+    def _drop_vars(self, names, errors='raise'):
+        if errors == 'raise':
+            self._assert_all_in_dataset(names)
         drop = set(names)
         variables = OrderedDict((k, v) for k, v in self._variables.items()
                                 if k not in drop)
         coord_names = set(k for k in self._coord_names if k in variables)
         return self._replace_vars_and_dims(variables, coord_names)
 
-    def drop_dims(self, drop_dims):
+    def drop_dims(self, drop_dims, *, errors='raise'):
         """Drop dimensions and associated variables from this dataset.
 
         Parameters
@@ -2875,14 +2883,23 @@ def drop_dims(self, drop_dims):
         obj : Dataset
             The dataset without the given dimensions (or any variables
             containing those dimensions)
+        errors : {'raise', 'ignore'}, optional
+            If 'raise' (default), raises a ValueError if
+            any of the dimensions passed are not
+            in the dataset. If 'ignore', any given dimensions that are in the
+            dataset are dropped and no error is raised.
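The new keyword mirrors pandas' ``errors`` option; a quick sketch of its effect across all three methods (hypothetical data, not part of the patch):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({'A': ('x', np.arange(3))}, coords={'x': ['a', 'b', 'c']})
ds.drop('B', errors='ignore')                  # no-op instead of ValueError
ds.drop(['a', 'z'], dim='x', errors='ignore')  # drops 'a', skips missing 'z'
ds.drop_dims('t', errors='ignore')             # returns the dataset unchanged
```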
""" + if errors not in ['raise', 'ignore']: + raise ValueError('errors must be either "raise" or "ignore"') + if utils.is_scalar(drop_dims): drop_dims = [drop_dims] - missing_dimensions = [d for d in drop_dims if d not in self.dims] - if missing_dimensions: - raise ValueError('Dataset does not contain the dimensions: %s' - % missing_dimensions) + if errors == 'raise': + missing_dimensions = [d for d in drop_dims if d not in self.dims] + if missing_dimensions: + raise ValueError('Dataset does not contain the dimensions: %s' + % missing_dimensions) drop_vars = set(k for k, v in self._variables.items() for d in v.dims if d in drop_dims) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index fd9076e7f65..a8825055479 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1859,19 +1859,34 @@ def test_drop_coordinates(self): with pytest.raises(ValueError): arr.drop('not found') + actual = expected.drop('not found', errors='ignore') + assert_identical(actual, expected) + with raises_regex(ValueError, 'cannot be found'): arr.drop(None) + actual = expected.drop(None, errors='ignore') + assert_identical(actual, expected) + renamed = arr.rename('foo') with raises_regex(ValueError, 'cannot be found'): renamed.drop('foo') + actual = renamed.drop('foo', errors='ignore') + assert_identical(actual, renamed) + def test_drop_index_labels(self): arr = DataArray(np.random.randn(2, 3), coords={'y': [0, 1, 2]}, dims=['x', 'y']) actual = arr.drop([0, 1], dim='y') expected = arr[:, 2:] - assert_identical(expected, actual) + assert_identical(actual, expected) + + with raises_regex((KeyError, ValueError), 'not .* in axis'): + actual = arr.drop([0, 1, 3], dim='y') + + actual = arr.drop([0, 1, 3], dim='y', errors='ignore') + assert_identical(actual, expected) def test_dropna(self): x = np.random.randn(4, 4) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 5aae56485ce..8cd129e35de 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1889,6 +1889,15 @@ def test_drop_variables(self): with raises_regex(ValueError, 'cannot be found'): data.drop('not_found_here') + actual = data.drop('not_found_here', errors='ignore') + assert_identical(data, actual) + + actual = data.drop(['not_found_here'], errors='ignore') + assert_identical(data, actual) + + actual = data.drop(['time', 'not_found_here'], errors='ignore') + assert_identical(expected, actual) + def test_drop_index_labels(self): data = Dataset({'A': (['x', 'y'], np.random.randn(2, 3)), 'x': ['a', 'b']}) @@ -1907,6 +1916,16 @@ def test_drop_index_labels(self): # not contained in axis data.drop(['c'], dim='x') + actual = data.drop(['c'], dim='x', errors='ignore') + assert_identical(data, actual) + + with pytest.raises(ValueError): + data.drop(['c'], dim='x', errors='wrong_value') + + actual = data.drop(['a', 'b', 'c'], 'x', errors='ignore') + expected = data.isel(x=slice(0, 0)) + assert_identical(expected, actual) + with raises_regex( ValueError, 'does not have coordinate labels'): data.drop(1, 'y') @@ -1931,6 +1950,22 @@ def test_drop_dims(self): with pytest.raises((ValueError, KeyError)): data.drop_dims('z') # not a dimension + with pytest.raises((ValueError, KeyError)): + data.drop_dims(None) + + actual = data.drop_dims('z', errors='ignore') + assert_identical(data, actual) + + actual = data.drop_dims(None, errors='ignore') + assert_identical(data, actual) + + with pytest.raises(ValueError): + actual = data.drop_dims('z', errors='wrong_value') + + actual = 
data.drop_dims(['x', 'y', 'z'], errors='ignore') + expected = data.drop(['A', 'B', 'x']) + assert_identical(expected, actual) + def test_copy(self): data = create_test_data() data.attrs['Test'] = [1, 2, 3] From 724ad8301c9177f0f18476c03f4843006a4de691 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 22 Jun 2019 13:16:35 -0400 Subject: [PATCH 26/31] Revert cmap fix (#3038) * Test for levels + provided cmap * Revert "plot: If provided with colormap do not modify it. (#2935)" This reverts commit ab3972294860447f9515c7b7b0a04838db061496. * lint --- doc/whats-new.rst | 2 -- xarray/plot/utils.py | 3 +-- xarray/tests/test_plot.py | 31 +++++++------------------------ 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ca50856a25e..52fa102f7fa 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -75,8 +75,6 @@ Bug fixes By `Mayeul d'Avezac `_. - Return correct count for scalar datetime64 arrays (:issue:`2770`) By `Dan Nowacki `_. -- Fix facetgrid colormap bug when ``extend=True``. (:issue:`2932`) - By `Deepak Cherian `_. - Increased support for `missing_value` (:issue:`2871`) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 18215479d8c..c9f72b177c6 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -264,8 +264,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, if extend is None: extend = _determine_extend(calc_data, vmin, vmax) - if ((levels is not None or isinstance(norm, mpl.colors.BoundaryNorm)) - and (not isinstance(cmap, mpl.colors.Colormap))): + if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm): cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) norm = newnorm if norm is None else norm diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index a0952cac47e..0dc5fb320f0 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -8,6 +8,7 @@ import xarray as xr import xarray.plot as xplt from xarray import DataArray +from xarray.coding.times import _import_cftime from xarray.plot.plot import _infer_interval_breaks from xarray.plot.utils import ( _build_discrete_cmap, _color_palette, _determine_cmap_params, @@ -537,25 +538,6 @@ def test_cmap_sequential_option(self): cmap_params = _determine_cmap_params(self.data) assert cmap_params['cmap'] == 'magma' - def test_do_nothing_if_provided_cmap(self): - cmap_list = [ - mpl.colors.LinearSegmentedColormap.from_list('name', ['r', 'g']), - mpl.colors.ListedColormap(['r', 'g', 'b']) - ] - - # can't parametrize with mpl objects when mpl is absent - for cmap in cmap_list: - cmap_params = _determine_cmap_params(self.data, - cmap=cmap, - levels=7) - assert cmap_params['cmap'] is cmap - - def test_do_something_if_provided_str_cmap(self): - cmap = 'RdBu_r' - cmap_params = _determine_cmap_params(self.data, cmap=cmap, levels=7) - assert cmap_params['cmap'] is not cmap - assert isinstance(cmap_params['cmap'], mpl.colors.ListedColormap) - def test_cmap_sequential_explicit_option(self): with xr.set_options(cmap_sequential=mpl.cm.magma): cmap_params = _determine_cmap_params(self.data) @@ -775,13 +757,14 @@ def test_discrete_colormap_list_of_levels(self): @pytest.mark.slow def test_discrete_colormap_int_levels(self): - for extend, levels, vmin, vmax in [('neither', 7, None, None), - ('neither', 7, None, 20), - ('both', 7, 4, 8), - ('min', 10, 4, 15)]: + for extend, levels, vmin, vmax, cmap in [ + ('neither', 7, None, None, None), + ('neither', 7, None, 20, mpl.cm.RdBu), + ('both', 7, 4, 8, 
None), + ('min', 10, 4, 15, None)]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)( - levels=levels, vmin=vmin, vmax=vmax) + levels=levels, vmin=vmin, vmax=vmax, cmap=cmap) assert levels >= \ len(primitive.norm.boundaries) - 1 if vmax is None: From ff4198865b42ee2f6f99f3f6f83fed68ef4ffbc7 Mon Sep 17 00:00:00 2001 From: Scott Wales Date: Sun, 23 Jun 2019 19:18:32 +1000 Subject: [PATCH 27/31] ENH: keepdims=True for xarray reductions (#3033) * ENH: keepdims=True for xarray reductions Addresses #2170 Add new option `keepdims` to xarray reduce operations, following the behaviour of Numpy. `keepdims` may be passed to reductions on either Datasets or DataArrays, and will result in the reduced dimensions being still present in the output with size 1. Coordinates that depend on the reduced dimensions will be removed from the Dataset/DataArray * Set the default to be `False` * Correct lint error * Apply suggestions from code review Co-Authored-By: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * Add test for dask and fix implementation * Move 'keepdims' up to where 'dims' is set * Fix lint, add test for scalar variable --- doc/whats-new.rst | 2 ++ xarray/core/dataarray.py | 18 +++++++++++++--- xarray/core/dataset.py | 9 ++++++-- xarray/core/variable.py | 20 +++++++++++++++--- xarray/tests/test_dataarray.py | 38 ++++++++++++++++++++++++++++++++++ xarray/tests/test_dataset.py | 19 +++++++++++++++++ xarray/tests/test_variable.py | 36 ++++++++++++++++++++++++++++++++ 7 files changed, 134 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 52fa102f7fa..373cb8d13dc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,8 @@ Enhancements ~~~~~~~~~~~~ +- Add ``keepdims`` argument for reduce operations (:issue:`2170`) + By `Scott Wales `_. - netCDF chunksizes are now only dropped when original_shape is different, not when it isn't found. (:issue:`2207`) By `Karel van de Plassche `_. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4c3dcc2781a..ff77a6ab704 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -259,8 +259,14 @@ def _replace(self, variable=None, coords=None, name=__default): return type(self)(variable, coords, name=name, fastpath=True) def _replace_maybe_drop_dims(self, variable, name=__default): - if variable.dims == self.dims: + if variable.dims == self.dims and variable.shape == self.shape: coords = self._coords.copy() + elif variable.dims == self.dims: + # Shape has changed (e.g. from reduce(..., keepdims=True) + new_sizes = dict(zip(self.dims, variable.shape)) + coords = OrderedDict((k, v) for k, v in self._coords.items() + if v.shape == tuple(new_sizes[d] + for d in v.dims)) else: allowed_dims = set(variable.dims) coords = OrderedDict((k, v) for k, v in self._coords.items() @@ -1642,7 +1648,8 @@ def combine_first(self, other): """ return ops.fillna(self, other, join="outer") - def reduce(self, func, dim=None, axis=None, keep_attrs=None, **kwargs): + def reduce(self, func, dim=None, axis=None, keep_attrs=None, + keepdims=False, **kwargs): """Reduce this array by applying `func` along some dimension(s). Parameters @@ -1662,6 +1669,10 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=None, **kwargs): If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. 
+ keepdims : bool, default False + If True, the dimensions which are reduced are left in the result + as dimensions of size one. Coordinates that use these dimensions + are removed. **kwargs : dict Additional keyword arguments passed on to `func`. @@ -1672,7 +1683,8 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=None, **kwargs): summarized data and the indicated dimension(s) removed. """ - var = self.variable.reduce(func, dim, axis, keep_attrs, **kwargs) + var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, + **kwargs) return self._replace_maybe_drop_dims(var) def to_pandas(self): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 13a6a6ee9b2..3e00640ba60 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3152,8 +3152,8 @@ def combine_first(self, other): out = ops.fillna(self, other, join="outer", dataset_join="outer") return out - def reduce(self, func, dim=None, keep_attrs=None, numeric_only=False, - allow_lazy=False, **kwargs): + def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, + numeric_only=False, allow_lazy=False, **kwargs): """Reduce this dataset by applying `func` along some dimension(s). Parameters @@ -3169,6 +3169,10 @@ def reduce(self, func, dim=None, keep_attrs=None, numeric_only=False, If True, the dataset's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. + keepdims : bool, default False + If True, the dimensions which are reduced are left in the result + as dimensions of size one. Coordinates that use these dimensions + are removed. numeric_only : bool, optional If True, only apply ``func`` to variables with a numeric dtype. **kwargs : dict @@ -3218,6 +3222,7 @@ def reduce(self, func, dim=None, keep_attrs=None, numeric_only=False, reduce_dims = None variables[name] = var.reduce(func, dim=reduce_dims, keep_attrs=keep_attrs, + keepdims=keepdims, allow_lazy=allow_lazy, **kwargs) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 41f8795b595..ab1be181e31 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1334,7 +1334,7 @@ def where(self, cond, other=dtypes.NA): return ops.where_method(self, cond, other) def reduce(self, func, dim=None, axis=None, - keep_attrs=None, allow_lazy=False, **kwargs): + keep_attrs=None, keepdims=False, allow_lazy=False, **kwargs): """Reduce this array by applying `func` along some dimension(s). Parameters @@ -1354,6 +1354,9 @@ def reduce(self, func, dim=None, axis=None, If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. + keepdims : bool, default False + If True, the dimensions which are reduced are left in the result + as dimensions of size one **kwargs : dict Additional keyword arguments passed on to `func`. 
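A sketch of the resulting behavior, which follows NumPy's ``keepdims`` convention while also pruning coordinates that depend on the reduced dimensions:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=['x', 'y'],
                  coords={'x': [10, 20], 'c': -1})
da.mean(keepdims=True).shape       # (1, 1); the 'x' coordinate is dropped
da.mean('y', keepdims=True).shape  # (2, 1); the 'x' coordinate survives
```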
@@ -1381,8 +1384,19 @@ def reduce(self, func, dim=None, axis=None, else: removed_axes = (range(self.ndim) if axis is None else np.atleast_1d(axis) % self.ndim) - dims = [adim for n, adim in enumerate(self.dims) - if n not in removed_axes] + if keepdims: + # Insert np.newaxis for removed dims + slices = tuple(np.newaxis if i in removed_axes else + slice(None, None) for i in range(self.ndim)) + if getattr(data, 'shape', None) is None: + # Reduce has produced a scalar value, not an array-like + data = np.asanyarray(data)[slices] + else: + data = data[slices] + dims = self.dims + else: + dims = [adim for n, adim in enumerate(self.dims) + if n not in removed_axes] if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a8825055479..47222194151 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1991,6 +1991,44 @@ def test_reduce(self): dims=['x', 'y']).mean('x') assert_equal(actual, expected) + def test_reduce_keepdims(self): + coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], + 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), + 'c': -999} + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + + # Mean on all axes loses non-constant coordinates + actual = orig.mean(keepdims=True) + expected = DataArray(orig.data.mean(keepdims=True), dims=orig.dims, + coords={k: v for k, v in coords.items() + if k in ['c']}) + assert_equal(actual, expected) + + assert actual.sizes['x'] == 1 + assert actual.sizes['y'] == 1 + + # Mean on specific axes loses coordinates not involving that axis + actual = orig.mean('y', keepdims=True) + expected = DataArray(orig.data.mean(axis=1, keepdims=True), + dims=orig.dims, + coords={k: v for k, v in coords.items() + if k not in ['y', 'lat']}) + assert_equal(actual, expected) + + @requires_bottleneck + def test_reduce_keepdims_bottleneck(self): + import bottleneck + + coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], + 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), + 'c': -999} + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + + # Bottleneck does not have its own keepdims implementation + actual = orig.reduce(bottleneck.nanmean, keepdims=True) + expected = orig.mean(keepdims=True) + assert_equal(actual, expected) + def test_reduce_dtype(self): coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'], 'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]), diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8cd129e35de..e3a01bbd3a1 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3898,6 +3898,25 @@ def total_sum(x): with raises_regex(TypeError, "unexpected keyword argument 'axis'"): ds.reduce(total_sum, dim='x') + def test_reduce_keepdims(self): + ds = Dataset({'a': (['x', 'y'], [[0, 1, 2, 3, 4]])}, + coords={'y': [0, 1, 2, 3, 4], 'x': [0], + 'lat': (['x', 'y'], [[0, 1, 2, 3, 4]]), + 'c': -999.0}) + + # Shape should match behaviour of numpy reductions with keepdims=True + # Coordinates involved in the reduction should be removed + actual = ds.mean(keepdims=True) + expected = Dataset({'a': (['x', 'y'], np.mean(ds.a, keepdims=True))}, + coords={'c': ds.c}) + assert_identical(expected, actual) + + actual = ds.mean('x', keepdims=True) + expected = Dataset({'a': (['x', 'y'], + np.mean(ds.a, axis=0, keepdims=True))}, + coords={'y': ds.y, 'c': ds.c}) + assert_identical(expected, actual) + def test_quantile(self): ds = create_test_data(seed=123) diff --git 
a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 4ddd114d767..5da83880539 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1540,6 +1540,42 @@ def test_reduce_funcs(self): assert_identical( v.max(), Variable([], pd.Timestamp('2000-01-03'))) + def test_reduce_keepdims(self): + v = Variable(['x', 'y'], self.d) + + assert_identical(v.mean(keepdims=True), + Variable(v.dims, np.mean(self.d, keepdims=True))) + assert_identical(v.mean(dim='x', keepdims=True), + Variable(v.dims, np.mean(self.d, axis=0, + keepdims=True))) + assert_identical(v.mean(dim='y', keepdims=True), + Variable(v.dims, np.mean(self.d, axis=1, + keepdims=True))) + assert_identical(v.mean(dim=['y', 'x'], keepdims=True), + Variable(v.dims, np.mean(self.d, axis=(1, 0), + keepdims=True))) + + v = Variable([], 1.0) + assert_identical(v.mean(keepdims=True), + Variable([], np.mean(v.data, keepdims=True))) + + @requires_dask + def test_reduce_keepdims_dask(self): + import dask.array + v = Variable(['x', 'y'], self.d).chunk() + + actual = v.mean(keepdims=True) + assert isinstance(actual.data, dask.array.Array) + + expected = Variable(v.dims, np.mean(self.d, keepdims=True)) + assert_identical(actual, expected) + + actual = v.mean(dim='y', keepdims=True) + assert isinstance(actual.data, dask.array.Array) + + expected = Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)) + assert_identical(actual, expected) + def test_reduce_keep_attrs(self): _attrs = {'units': 'test', 'long_name': 'testing'} From 56fc325ae029b5aa8fc4556197103a1a2eb31702 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 23 Jun 2019 07:47:04 -0400 Subject: [PATCH 28/31] add back dask-dev tests (#3025) --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index ee242ebf818..6f63f305ed2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,13 +19,13 @@ matrix: - env: CONDA_ENV=py36-pandas-dev - env: CONDA_ENV=py36-rasterio - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=py36-dask-dev - env: CONDA_ENV=docs - env: CONDA_ENV=lint - env: CONDA_ENV=typing - env: CONDA_ENV=py36-hypothesis allow_failures: - - env: CONDA_ENV=py36-dask-dev - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" From 223a05f1b77d4efe8ac7d4dc2c24bff61335693c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 23 Jun 2019 19:49:22 +0300 Subject: [PATCH 29/31] Ensure explicitly indexed arrays are preserved (#3027) * Ensure indexing explicitly indexed arrays don't leak out. Previously, indexing an ImplicitToExplicitIndexingAdapter object could directly return an ExplicitlyIndexed object, which could not be indexed normally. This resulted in broken behavior with dask's new `_meta` attribute. This change almost but not entirely fixes xarray on dask master. 
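Concretely, the behavior this change guarantees (a sketch against xarray's private indexing internals, mirroring the new test below):

```python
import numpy as np
from xarray.core import indexing

wrapped = indexing.ImplicitToExplicitIndexingAdapter(
    indexing.CopyOnWriteArray(np.arange(10)))
# indexing re-wraps the result instead of leaking the underlying
# ExplicitlyIndexed object, so the result can be indexed again
assert isinstance(wrapped[:], indexing.ImplicitToExplicitIndexingAdapter)
```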
There are still errors raised inside two tests from dask's `blockwise_meta` helper function: > return meta.astype(dtype) E AttributeError: 'ImplicitToExplicitIndexingAdapter' object has no attribute 'astype' * Set meta in dask.array.from_array --- .travis.yml | 2 +- xarray/core/indexing.py | 8 +++++++- xarray/core/variable.py | 14 +++++++++++++- xarray/tests/test_indexing.py | 9 ++++++++- 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6f63f305ed2..efa903f5083 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,10 +16,10 @@ matrix: - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" + - env: CONDA_ENV=py36-dask-dev - env: CONDA_ENV=py36-pandas-dev - env: CONDA_ENV=py36-rasterio - env: CONDA_ENV=py36-zarr-dev - - env: CONDA_ENV=py36-dask-dev - env: CONDA_ENV=docs - env: CONDA_ENV=lint - env: CONDA_ENV=typing diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 1effb9347dd..1ba3175dc2f 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -453,7 +453,13 @@ def __array__(self, dtype=None): def __getitem__(self, key): key = expanded_indexer(key, self.ndim) - return self.array[self.indexer_cls(key)] + result = self.array[self.indexer_cls(key)] + if isinstance(result, ExplicitlyIndexed): + return type(self)(result, self.indexer_cls) + else: + # Sometimes explicitly indexed arrays return NumPy arrays or + # scalars. + return result class LazilyOuterIndexedArray(ExplicitlyIndexedNDArrayMixin): diff --git a/xarray/core/variable.py b/xarray/core/variable.py index ab1be181e31..cccb9663ad5 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -3,6 +3,7 @@ import typing from collections import OrderedDict, defaultdict from datetime import timedelta +from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -870,6 +871,7 @@ def chunk(self, chunks=None, name=None, lock=False): ------- chunked : xarray.Variable """ + import dask import dask.array as da if utils.is_dict_like(chunks): @@ -892,7 +894,17 @@ def chunk(self, chunks=None, name=None, lock=False): # https://github.com/dask/dask/issues/2883 data = indexing.ImplicitToExplicitIndexingAdapter( data, indexing.OuterIndexer) - data = da.from_array(data, chunks, name=name, lock=lock) + + # For now, assume that all arrays that we wrap with dask (including + # our lazily loaded backend array classes) should use NumPy array + # operations. 
+ if LooseVersion(dask.__version__) > '1.2.2': + kwargs = dict(meta=np.ndarray) + else: + kwargs = dict() + + data = da.from_array( + data, chunks, name=name, lock=lock, **kwargs) return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 9301abb5e32..59435fea88b 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -505,13 +505,20 @@ def test_decompose_indexers(shape, indexer_mode, indexing_support): def test_implicit_indexing_adapter(): - array = np.arange(10) + array = np.arange(10, dtype=np.int64) implicit = indexing.ImplicitToExplicitIndexingAdapter( indexing.NumpyIndexingAdapter(array), indexing.BasicIndexer) np.testing.assert_array_equal(array, np.asarray(implicit)) np.testing.assert_array_equal(array, implicit[:]) +def test_implicit_indexing_adapter_copy_on_write(): + array = np.arange(10, dtype=np.int64) + implicit = indexing.ImplicitToExplicitIndexingAdapter( + indexing.CopyOnWriteArray(array)) + assert isinstance(implicit[:], indexing.ImplicitToExplicitIndexingAdapter) + + def test_outer_indexer_consistency_with_broadcast_indexes_vectorized(): def nonzero(x): if isinstance(x, np.ndarray) and x.dtype.kind == 'b': From cfd821065341e386b3e4a1e6e09bf8d952ed0e2a Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 24 Jun 2019 11:20:31 -0400 Subject: [PATCH 30/31] rolling_exp (nee ewm) (#2650) * WIP on ewm using numbagg * basic functionality, no dims working yet * rename to `rolling_exp` * ensure works on either dimensions * window_type working * add numbagg to travis install * naming * formatting * @shoyer's function to abstract the type of self.obj * initial docstring * add docstrings to docs * example * correct location for docs * add numbagg to print_versions * whatsnew * updating my GH username * pin to numbagg release * rename inner func to move_exp_nanmean * merge * typo * comments from PR * window -> alpha in numbagg * add docs * doc fix * whatsnew update * revert formatting changes to unchanged file * update docstrings, adjust kwarg names * mypy * flake * pytest config tiny tweak while I'm here * Rolling exp doc updates * remove _attributes from RollingExp class --- ci/requirements-py37.yml | 1 + doc/api.rst | 3 + doc/computation.rst | 16 +++++ doc/installing.rst | 2 + doc/whats-new.rst | 20 ++++--- setup.cfg | 11 ++-- xarray/core/common.py | 55 +++++++++++++++-- xarray/core/rolling_exp.py | 106 +++++++++++++++++++++++++++++++++ xarray/tests/__init__.py | 1 + xarray/tests/test_dataarray.py | 49 +++++++++++---- xarray/tests/test_dataset.py | 9 ++- xarray/util/print_versions.py | 1 + 12 files changed, 244 insertions(+), 30 deletions(-) create mode 100644 xarray/core/rolling_exp.py diff --git a/ci/requirements-py37.yml b/ci/requirements-py37.yml index fe5afd589c8..723ad24d24d 100644 --- a/ci/requirements-py37.yml +++ b/ci/requirements-py37.yml @@ -30,3 +30,4 @@ dependencies: - pydap - pip: - mypy==0.650 + - numbagg diff --git a/doc/api.rst b/doc/api.rst index 33c8d9d3ceb..811e3241438 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -148,6 +148,7 @@ Computation Dataset.groupby Dataset.groupby_bins Dataset.rolling + Dataset.rolling_exp Dataset.coarsen Dataset.resample Dataset.diff @@ -315,6 +316,7 @@ Computation DataArray.groupby DataArray.groupby_bins DataArray.rolling + DataArray.rolling_exp DataArray.coarsen DataArray.dt DataArray.resample @@ -535,6 +537,7 @@ Rolling objects core.rolling.DatasetRolling 
    core.rolling.DatasetRolling.construct
    core.rolling.DatasetRolling.reduce
+   core.rolling_exp.RollingExp
 
 Resample objects
 ================
diff --git a/doc/computation.rst b/doc/computation.rst
index 3100925a7d3..b06d7959504 100644
--- a/doc/computation.rst
+++ b/doc/computation.rst
@@ -190,6 +190,22 @@ We can also manually iterate through ``Rolling`` objects:
         for label, arr_window in r:
            # arr_window is a view of x
 
+.. _comput.rolling_exp:
+
+While ``rolling`` provides a simple moving average, ``DataArray`` also supports
+an exponential moving average with :py:meth:`~xarray.DataArray.rolling_exp`.
+This is similar to pandas' ``ewm`` method. numbagg_ is required.
+
+.. _numbagg: https://github.com/shoyer/numbagg
+
+.. code:: python
+
+    arr.rolling_exp(y=3).mean()
+
+The ``rolling_exp`` method takes a ``window_type`` kwarg, which can be ``'alpha'``,
+``'com'`` (for ``center-of-mass``), ``'span'``, and ``'halflife'``. The default is
+``span``.
+
 Finally, the rolling object has a ``construct`` method which returns a
 view of the original ``DataArray`` with the windowed dimension in
 the last position.
diff --git a/doc/installing.rst b/doc/installing.rst
index f624da18611..b9d1b4d0ba4 100644
--- a/doc/installing.rst
+++ b/doc/installing.rst
@@ -45,6 +45,8 @@ For accelerating xarray
 - `bottleneck `__: speeds up
   NaN-skipping and rolling window aggregations by a large factor
   (1.1 or later)
+- `numbagg `_: for exponential rolling
+  window operations
 
 For parallel computing
 ~~~~~~~~~~~~~~~~~~~~~~
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 373cb8d13dc..b48614bd0ff 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -32,6 +32,11 @@ Enhancements
 - Add ``fill_value`` argument for reindex, align, and merge operations
   to enable custom fill values. (:issue:`2876`)
   By `Zach Griffith `_.
+- :py:meth:`~xarray.DataArray.rolling_exp` and
+  :py:meth:`~xarray.Dataset.rolling_exp` added, similar to pandas'
+  ``pd.DataFrame.ewm`` method. Calling ``.mean`` on the resulting object
+  will return an exponentially weighted moving average.
+  By `Maximilian Roos `_.
 - Character arrays' character dimension name decoding and encoding handled by
   ``var.encoding['char_dim_name']`` (:issue:`2895`)
   By `James McCreight `_.
@@ -188,6 +193,7 @@ Other enhancements
 - Upsampling an array via interpolation with resample is now dask-compatible,
   as long as the array is not chunked along the resampling dimension.
   By `Spencer Clark `_.
+
 - :py:func:`xarray.testing.assert_equal` and
   :py:func:`xarray.testing.assert_identical` now provide a more detailed
   report showing what exactly differs between the two objects (dimensions /
@@ -737,7 +743,7 @@ Enhancements
   arguments in ``data_vars`` to indexes set explicitly in ``coords``, where
   previously an error would be raised.
   (:issue:`674`)
-  By `Maximilian Roos `_.
+  By `Maximilian Roos `_.
 - :py:meth:`~DataArray.sel`, :py:meth:`~DataArray.isel` &
   :py:meth:`~DataArray.reindex`, (and their :py:class:`Dataset` counterparts)
   now support supplying a ``dict``
@@ -745,12 +751,12 @@ Enhancements
   of supplying `kwargs`. This allows for more robust behavior
   of dimension names which conflict with other keyword names, or are
   not strings.
-  By `Maximilian Roos `_.
+  By `Maximilian Roos `_.
 - :py:meth:`~DataArray.rename` now supports supplying ``**kwargs``, as an
   alternative to the existing approach of supplying a ``dict`` as the first
   argument.
-  By `Maximilian Roos `_.
+  By `Maximilian Roos `_.
- :py:meth:`~DataArray.cumsum` and :py:meth:`~DataArray.cumprod` now support aggregation over multiple dimensions at the same time. This is the default @@ -915,7 +921,7 @@ Enhancements which test each value in the array for whether it is contained in the supplied list, returning a bool array. See :ref:`selecting values with isin` for full details. Similar to the ``np.isin`` function. - By `Maximilian Roos `_. + By `Maximilian Roos `_. - Some speed improvement to construct :py:class:`~xarray.DataArrayRolling` object (:issue:`1993`) By `Keisuke Fujii `_. @@ -2110,7 +2116,7 @@ Enhancements ~~~~~~~~~~~~ - New documentation on :ref:`panel transition`. By - `Maximilian Roos `_. + `Maximilian Roos `_. - New ``Dataset`` and ``DataArray`` methods :py:meth:`~xarray.Dataset.to_dict` and :py:meth:`~xarray.Dataset.from_dict` to allow easy conversion between dictionaries and xarray objects (:issue:`432`). See @@ -2131,9 +2137,9 @@ Bug fixes (:issue:`953`). By `Stephan Hoyer `_. - ``Dataset.__dir__()`` (i.e. the method python calls to get autocomplete options) failed if one of the dataset's keys was not a string (:issue:`852`). - By `Maximilian Roos `_. + By `Maximilian Roos `_. - ``Dataset`` constructor can now take arbitrary objects as values - (:issue:`647`). By `Maximilian Roos `_. + (:issue:`647`). By `Maximilian Roos `_. - Clarified ``copy`` argument for :py:meth:`~xarray.DataArray.reindex` and :py:func:`~xarray.align`, which now consistently always return new xarray objects (:issue:`927`). diff --git a/setup.cfg b/setup.cfg index cdfe2ec3e36..bfa49118d84 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,8 +9,11 @@ filterwarnings = ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning env = UVCDAT_ANONYMOUS_LOG=no +markers = + flaky: flaky tests + network: tests requiring a network connection + slow: slow tests -# This should be kept in sync with .pep8speaks.yml [flake8] max-line-length=79 ignore= @@ -23,10 +26,6 @@ ignore= F401 exclude= doc -markers = - flaky: flaky tests - network: tests requiring a network connection - slow: slow tests [isort] default_section=THIRDPARTY @@ -62,6 +61,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-nc_time_axis.*] ignore_missing_imports = True +[mypy-numbagg.*] +ignore_missing_imports = True [mypy-numpy.*] ignore_missing_imports = True [mypy-netCDF4.*] diff --git a/xarray/core/common.py b/xarray/core/common.py index 4e5133fd8c6..0195be62500 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -12,6 +12,7 @@ from .arithmetic import SupportsArithmetic from .options import _get_keep_attrs from .pycompat import dask_array_type +from .rolling_exp import RollingExp from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs # Used as a sentinel value to indicate a all dimensions @@ -86,6 +87,7 @@ def wrapped_func(self, dim=None, **kwargs): # type: ignore class AbstractArray(ImplementsArrayReduce): """Shared base class for DataArray and Variable. 
""" + def __bool__(self: Any) -> bool: return bool(self.values) @@ -249,6 +251,8 @@ def get_squeeze_dims(xarray_obj, class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" + _rolling_exp_cls = RollingExp + def squeeze(self, dim: Union[Hashable, Iterable[Hashable], None] = None, drop: bool = False, axis: Union[int, Iterable[int], None] = None): @@ -553,7 +557,7 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None, def rolling(self, dim: Optional[Mapping[Hashable, int]] = None, min_periods: Optional[int] = None, center: bool = False, - **dim_kwargs: int): + **window_kwargs: int): """ Rolling window object. @@ -568,9 +572,9 @@ def rolling(self, dim: Optional[Mapping[Hashable, int]] = None, setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **dim_kwargs : optional + **window_kwargs : optional The keyword arguments form of ``dim``. - One of dim or dim_kwargs must be provided. + One of dim or window_kwargs must be provided. Returns ------- @@ -609,15 +613,54 @@ def rolling(self, dim: Optional[Mapping[Hashable, int]] = None, core.rolling.DataArrayRolling core.rolling.DatasetRolling """ # noqa - dim = either_dict_or_kwargs(dim, dim_kwargs, 'rolling') + dim = either_dict_or_kwargs(dim, window_kwargs, 'rolling') return self._rolling_cls(self, dim, min_periods=min_periods, center=center) + def rolling_exp( + self, + window: Optional[Mapping[Hashable, int]] = None, + window_type: str = 'span', + **window_kwargs + ): + """ + Exponentially-weighted moving window. + Similar to EWM in pandas + + Requires the optional Numbagg dependency. + + Parameters + ---------- + window : A single mapping from a dimension name to window value, + optional + dim : str + Name of the dimension to create the rolling exponential window + along (e.g., `time`). + window : int + Size of the moving window. The type of this is specified in + `window_type` + window_type : str, one of ['span', 'com', 'halflife', 'alpha'], + default 'span' + The format of the previously supplied window. Each is a simple + numerical transformation of the others. Described in detail: + https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.ewm.html + **window_kwargs : optional + The keyword arguments form of ``window``. + One of window or window_kwargs must be provided. + + See Also + -------- + core.rolling_exp.RollingExp + """ + window = either_dict_or_kwargs(window, window_kwargs, 'rolling_exp') + + return self._rolling_exp_cls(self, window, window_type) + def coarsen(self, dim: Optional[Mapping[Hashable, int]] = None, boundary: str = 'exact', side: Union[str, Mapping[Hashable, str]] = 'left', coord_func: str = 'mean', - **dim_kwargs: int): + **window_kwargs: int): """ Coarsen object. 
@@ -671,7 +714,7 @@ def coarsen(self, dim: Optional[Mapping[Hashable, int]] = None,
         core.rolling.DataArrayCoarsen
         core.rolling.DatasetCoarsen
         """
-        dim = either_dict_or_kwargs(dim, dim_kwargs, 'coarsen')
+        dim = either_dict_or_kwargs(dim, window_kwargs, 'coarsen')
         return self._coarsen_cls(
             self, dim, boundary=boundary, side=side, coord_func=coord_func)
diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py
new file mode 100644
index 00000000000..ff6baef5c3a
--- /dev/null
+++ b/xarray/core/rolling_exp.py
@@ -0,0 +1,106 @@
+import numpy as np
+
+from .pycompat import dask_array_type
+
+
+def _get_alpha(com=None, span=None, halflife=None, alpha=None):
+    # pandas defines in terms of com (converting to alpha in the algo)
+    # so use its function to get a com and then convert to alpha
+
+    com = _get_center_of_mass(com, span, halflife, alpha)
+    return 1 / (1 + com)
+
+
+def move_exp_nanmean(array, *, axis, alpha):
+    if isinstance(array, dask_array_type):
+        raise TypeError("rolling_exp is not currently supported for dask arrays")
+    import numbagg
+    if axis == ():
+        return array.astype(np.float64)
+    else:
+        return numbagg.move_exp_nanmean(
+            array, axis=axis, alpha=alpha)
+
+
+def _get_center_of_mass(comass, span, halflife, alpha):
+    """
+    Vendored from pandas.core.window._get_center_of_mass
+
+    See licenses/PANDAS_LICENSE for the function's license
+    """
+    from pandas.core import common as com
+    valid_count = com.count_not_none(comass, span, halflife, alpha)
+    if valid_count > 1:
+        raise ValueError("comass, span, halflife, and alpha "
+                         "are mutually exclusive")
+
+    # Convert to center of mass; domain checks ensure 0 < alpha <= 1
+    if comass is not None:
+        if comass < 0:
+            raise ValueError("comass must satisfy: comass >= 0")
+    elif span is not None:
+        if span < 1:
+            raise ValueError("span must satisfy: span >= 1")
+        comass = (span - 1) / 2.
+    elif halflife is not None:
+        if halflife <= 0:
+            raise ValueError("halflife must satisfy: halflife > 0")
+        decay = 1 - np.exp(np.log(0.5) / halflife)
+        comass = 1 / decay - 1
+    elif alpha is not None:
+        if alpha <= 0 or alpha > 1:
+            raise ValueError("alpha must satisfy: 0 < alpha <= 1")
+        comass = (1.0 - alpha) / alpha
+    else:
+        raise ValueError("Must pass one of comass, span, halflife, or alpha")
+
+    return float(comass)
+
+
+class RollingExp:
+    """
+    Exponentially-weighted moving window object.
+    Similar to EWM in pandas.
+
+    Parameters
+    ----------
+    obj : Dataset or DataArray
+        Object to window.
+    windows : A single mapping from a single dimension name to window value
+        dim : str
+            Name of the dimension to create the rolling exponential window
+            along (e.g., `time`).
+        window : int
+            Size of the moving window. The type of this is specified in
+            `window_type`
+    window_type : str, one of ['span', 'com', 'halflife', 'alpha'], default 'span'
+        The format of the previously supplied window. Each is a simple
+        numerical transformation of the others. Described in detail:
+        https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.ewm.html
+
+    Returns
+    -------
+    RollingExp : type of input argument
+    """  # noqa
+
+    def __init__(self, obj, windows, window_type='span'):
+        self.obj = obj
+        dim, window = next(iter(windows.items()))
+        self.dim = dim
+        self.alpha = _get_alpha(**{window_type: window})
+
+    def mean(self):
+        """
+        Exponentially weighted moving average
+
+        Examples
+        --------
+        >>> da = xr.DataArray([1,1,2,2,2], dims='x')
+        >>> da.rolling_exp(x=2, window_type='span').mean()
+        <xarray.DataArray (x: 5)>
+        array([1.      , 1.
, 1.692308, 1.9 , 1.966942]) + Dimensions without coordinates: x + """ + + return self.obj.reduce( + move_exp_nanmean, dim=self.dim, alpha=self.alpha) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index dc8b26e4524..81bb1a1e18d 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -74,6 +74,7 @@ def LooseVersion(vstring): has_np113, requires_np113 = _importorskip('numpy', minversion='1.13.0') has_iris, requires_iris = _importorskip('iris') has_cfgrib, requires_cfgrib = _importorskip('cfgrib') +has_numbagg, requires_numbagg = _importorskip('numbagg') # some special cases has_h5netcdf07, requires_h5netcdf07 = _importorskip('h5netcdf', diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 47222194151..b7235629d7a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -20,7 +20,7 @@ LooseVersion, ReturnItem, assert_allclose, assert_array_equal, assert_equal, assert_identical, raises_regex, requires_bottleneck, requires_cftime, requires_dask, requires_iris, requires_np113, - requires_scipy, source_ndarray) + requires_numbagg, requires_scipy, source_ndarray) class TestDataArray: @@ -3957,14 +3957,14 @@ def test_to_and_from_iris(self): assert coord.var_name == original_coord.name assert_array_equal( coord.points, CFDatetimeCoder().encode(original_coord).values) - assert (actual.coord_dims(coord) == - original.get_axis_num( + assert (actual.coord_dims(coord) + == original.get_axis_num( original.coords[coord.var_name].dims)) - assert (actual.coord('distance2').attributes['foo'] == - original.coords['distance2'].attrs['foo']) - assert (actual.coord('distance').units == - cf_units.Unit(original.coords['distance'].units)) + assert (actual.coord('distance2').attributes['foo'] + == original.coords['distance2'].attrs['foo']) + assert (actual.coord('distance').units + == cf_units.Unit(original.coords['distance'].units)) assert actual.attributes['baz'] == original.attrs['baz'] assert actual.standard_name == original.attrs['standard_name'] @@ -4022,14 +4022,14 @@ def test_to_and_from_iris_dask(self): assert coord.var_name == original_coord.name assert_array_equal( coord.points, CFDatetimeCoder().encode(original_coord).values) - assert (actual.coord_dims(coord) == - original.get_axis_num( + assert (actual.coord_dims(coord) + == original.get_axis_num( original.coords[coord.var_name].dims)) assert (actual.coord('distance2').attributes['foo'] == original.coords[ 'distance2'].attrs['foo']) - assert (actual.coord('distance').units == - cf_units.Unit(original.coords['distance'].units)) + assert (actual.coord('distance').units + == cf_units.Unit(original.coords['distance'].units)) assert actual.attributes['baz'] == original.attrs['baz'] assert actual.standard_name == original.attrs['standard_name'] @@ -4125,3 +4125,30 @@ def test_fallback_to_iris_AuxCoord(self, coord_values): expected = Cube(data, aux_coords_and_dims=[ (AuxCoord(coord_values, var_name='space'), 0)]) assert result == expected + + +@requires_numbagg +@pytest.mark.parametrize('dim', ['time', 'x']) +@pytest.mark.parametrize('window_type, window', [ + ['span', 5], + ['alpha', 0.5], + ['com', 0.5], + ['halflife', 5], +]) +def test_rolling_exp(da, dim, window_type, window): + da = da.isel(a=0) + da = da.where(da > 0.2) + + result = da.rolling_exp(window_type=window_type, **{dim: window}).mean() + assert isinstance(result, DataArray) + + pandas_array = da.to_pandas() + assert pandas_array.index.name == 'time' + if dim == 'x': + pandas_array = pandas_array.T + 
expected = ( + xr.DataArray(pandas_array.ewm(**{window_type: window}).mean()) + .transpose(*da.dims) + ) + + assert_allclose(expected.variable, result.variable) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index e3a01bbd3a1..1265f6a337a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -23,7 +23,7 @@ InaccessibleArray, UnexpectedDataAccess, assert_allclose, assert_array_equal, assert_equal, assert_identical, has_cftime, has_dask, raises_regex, requires_bottleneck, requires_cftime, requires_dask, - requires_scipy, source_ndarray) + requires_numbagg, requires_scipy, source_ndarray) try: import dask.array as da @@ -4828,6 +4828,13 @@ def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key): assert_equal(actual, ds['time']) +@requires_numbagg +def test_rolling_exp(ds): + + result = ds.rolling_exp(time=10, window_type='span').mean() + assert isinstance(result, Dataset) + + @pytest.mark.parametrize('center', (True, False)) @pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) @pytest.mark.parametrize('window', (1, 2, 3, 4)) diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index 50389df85cb..c34faa7487b 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -108,6 +108,7 @@ def show_versions(as_json=False): ("matplotlib", lambda mod: mod.__version__), ("cartopy", lambda mod: mod.__version__), ("seaborn", lambda mod: mod.__version__), + ("numbagg", lambda mod: mod.__version__), # xarray setup/test ("setuptools", lambda mod: mod.__version__), ("pip", lambda mod: mod.__version__), From b054c317f86639cd3b889a96d77ddb3798f8584e Mon Sep 17 00:00:00 2001 From: David Huard Date: Mon, 24 Jun 2019 11:21:28 -0400 Subject: [PATCH 31/31] Add quantile method to GroupBy (#2828) * implement groupby.quantile + tests * added quantile method in whats-new * mark additional test as xfail. * lint fix * simpler version of groupby.quantile * added quantile methods to api.rst * included DEFAULT_DIMS handling in quantile method * clarified groupby tests * added test with more typical use case * pep8 * removed failing test --- doc/api.rst | 3 +- doc/whats-new.rst | 5 +-- xarray/core/groupby.py | 58 ++++++++++++++++++++++++++++++++++ xarray/tests/test_groupby.py | 60 ++++++++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 3 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 811e3241438..258d1748c1b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -190,6 +190,7 @@ Computation :py:attr:`~core.groupby.DatasetGroupBy.last` :py:attr:`~core.groupby.DatasetGroupBy.fillna` :py:attr:`~core.groupby.DatasetGroupBy.where` +:py:attr:`~core.groupby.DatasetGroupBy.quantile` Reshaping and reorganizing -------------------------- @@ -362,7 +363,7 @@ Computation :py:attr:`~core.groupby.DataArrayGroupBy.last` :py:attr:`~core.groupby.DataArrayGroupBy.fillna` :py:attr:`~core.groupby.DataArrayGroupBy.where` - +:py:attr:`~core.groupby.DataArrayGroupBy.quantile` Reshaping and reorganizing -------------------------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b48614bd0ff..0275630f4c2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,7 +21,8 @@ v0.12.2 (unreleased) Enhancements ~~~~~~~~~~~~ - +- New :py:meth:`~xarray.GroupBy.quantile` method. (:issue:`3018`) + By `David Huard `_. - Add ``keepdims`` argument for reduce operations (:issue:`2170`) By `Scott Wales `_. 
 - netCDF chunksizes are now only dropped when original_shape is different,
@@ -90,7 +91,7 @@ Bug fixes
   By `Maximilian Roos <https://github.com/max-sixty>`_.
 - Fixed performance issues with cftime installed (:issue:`3000`)
   By `0x0L <https://github.com/0x0L>`_.
-- Replace incorrect usages of `message` in pytest assertions 
+- Replace incorrect usages of `message` in pytest assertions
   with `match` (:issue:`3011`)
   By `Maximilian Roos <https://github.com/max-sixty>`_.
 - Add explicit pytest markers, now required by pytest
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index d7dcb5b0426..108e85f729f 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -595,6 +595,64 @@ def _combine(self, applied, restore_coord_dims=False, shortcut=False):
         combined = self._maybe_unstack(combined)
         return combined
 
+    def quantile(self, q, dim=None, interpolation='linear', keep_attrs=None):
+        """Compute the qth quantile over each array in the groups and
+        concatenate them together into a new array.
+
+        Parameters
+        ----------
+        q : float in range of [0,1] (or sequence of floats)
+            Quantile to compute, which must be between 0 and 1
+            inclusive.
+        dim : str or sequence of str, optional
+            Dimension(s) over which to apply quantile.
+            Defaults to the grouped dimension.
+        interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
+            This optional parameter specifies the interpolation method to
+            use when the desired quantile lies between two data points
+            ``i < j``:
+                * linear: ``i + (j - i) * fraction``, where ``fraction`` is
+                  the fractional part of the index surrounded by ``i`` and
+                  ``j``.
+                * lower: ``i``.
+                * higher: ``j``.
+                * nearest: ``i`` or ``j``, whichever is nearest.
+                * midpoint: ``(i + j) / 2``.
+
+        Returns
+        -------
+        quantiles : DataArray or Dataset
+            If `q` is a single quantile, the ``quantile`` dimension is
+            dropped from the result. If multiple quantiles are given, the
+            first axis of the result corresponds to the quantile and a
+            quantile dimension is added to the return array. The other
+            dimensions are the dimensions that remain after the reduction
+            of the array.
+
+        See Also
+        --------
+        numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile,
+        DataArray.quantile
+        """
+        if dim == DEFAULT_DIMS:
+            dim = ALL_DIMS
+            # TODO change this to dim = self._group_dim after
+            # the deprecation process
+            if self._obj.ndim > 1:
+                warnings.warn(
+                    "Default reduction dimension will be changed to the "
+                    "grouped dimension in a future version of xarray. 
To " + "silence this warning, pass dim=xarray.ALL_DIMS " + "explicitly.", + FutureWarning, stacklevel=2) + + out = self.apply(self._obj.__class__.quantile, shortcut=False, + q=q, dim=dim, interpolation=interpolation, + keep_attrs=keep_attrs) + + if np.asarray(q, dtype=np.float64).ndim == 0: + out = out.drop('quantile') + return out + def reduce(self, func, dim=None, axis=None, keep_attrs=None, shortcut=True, **kwargs): """Reduce the items in this group by applying `func` along some diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b623c9bf05d..5433bd00f9d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -105,4 +105,64 @@ def func(arg1, arg2, arg3=0): assert_identical(expected, actual) +def test_da_groupby_quantile(): + + array = xr.DataArray([1, 2, 3, 4, 5, 6], + [('x', [1, 1, 1, 2, 2, 2])]) + + # Scalar quantile + expected = xr.DataArray([2, 5], [('x', [1, 2])]) + actual = array.groupby('x').quantile(.5) + assert_identical(expected, actual) + + # Vector quantile + expected = xr.DataArray([[1, 3], [4, 6]], + [('x', [1, 2]), ('quantile', [0, 1])]) + actual = array.groupby('x').quantile([0, 1]) + assert_identical(expected, actual) + + # Multiple dimensions + array = xr.DataArray([[1, 11, 26], [2, 12, 22], [3, 13, 23], + [4, 16, 24], [5, 15, 25]], + [('x', [1, 1, 1, 2, 2],), + ('y', [0, 0, 1])]) + + actual_x = array.groupby('x').quantile(0) + expected_x = xr.DataArray([1, 4], + [('x', [1, 2]), ]) + assert_identical(expected_x, actual_x) + + actual_y = array.groupby('y').quantile(0) + expected_y = xr.DataArray([1, 22], + [('y', [0, 1]), ]) + assert_identical(expected_y, actual_y) + + actual_xx = array.groupby('x').quantile(0, dim='x') + expected_xx = xr.DataArray([[1, 11, 22], [4, 15, 24]], + [('x', [1, 2]), ('y', [0, 0, 1])]) + assert_identical(expected_xx, actual_xx) + + actual_yy = array.groupby('y').quantile(0, dim='y') + expected_yy = xr.DataArray([[1, 26], [2, 22], [3, 23], [4, 24], [5, 25]], + [('x', [1, 1, 1, 2, 2]), ('y', [0, 1])]) + assert_identical(expected_yy, actual_yy) + + times = pd.date_range('2000-01-01', periods=365) + x = [0, 1] + foo = xr.DataArray(np.reshape(np.arange(365 * 2), (365, 2)), + coords=dict(time=times, x=x), dims=('time', 'x')) + g = foo.groupby(foo.time.dt.month) + + actual = g.quantile(0) + expected = xr.DataArray([0., 62., 120., 182., 242., 304., + 364., 426., 488., 548., 610., 670.], + [('month', np.arange(1, 13))]) + assert_identical(expected, actual) + + actual = g.quantile(0, dim='time')[:2] + expected = xr.DataArray([[0., 1], [62., 63]], + [('month', [1, 2]), ('x', [0, 1])]) + assert_identical(expected, actual) + + # TODO: move other groupby tests from test_dataset and test_dataarray over here
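
As a usage sketch of the ``rolling_exp`` method added above (assuming the
optional numbagg dependency is installed; the variable names are
illustrative), the span/com equivalence follows directly from the vendored
center-of-mass conversion::

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(10, dtype=float), dims='time')

    # span, com, halflife, and alpha parametrize the same smoothing factor:
    # span=5 gives com = (5 - 1) / 2 = 2, i.e. alpha = 1 / (1 + 2) = 1/3,
    # so these two calls should produce identical results.
    a = da.rolling_exp(time=5, window_type='span').mean()
    b = da.rolling_exp(time=2, window_type='com').mean()
    np.testing.assert_allclose(a.values, b.values)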
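The conversions in ``_get_center_of_mass`` reduce every window type to a
center of mass, from which ``_get_alpha`` computes ``alpha = 1 / (1 + com)``.
A standalone sketch of that arithmetic (plain NumPy; the helper names are
hypothetical, not part of the patch)::

    import numpy as np

    def alpha_from_span(span):
        # vendored pandas rule (span >= 1): com = (span - 1) / 2
        return 1 / (1 + (span - 1) / 2.0)

    def alpha_from_halflife(halflife):
        # halflife > 0: com = 1/decay - 1, so alpha reduces to the decay itself
        decay = 1 - np.exp(np.log(0.5) / halflife)
        return 1 / (1 + (1 / decay - 1))

    assert np.isclose(alpha_from_span(2), 2 / 3)    # the span in the doctest
    assert np.isclose(alpha_from_halflife(1), 0.5)  # one halflife halves the weight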
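Similarly, a minimal sketch of the new ``GroupBy.quantile`` method, reusing
the values from the scalar and vector cases in ``test_da_groupby_quantile``
above::

    import xarray as xr

    array = xr.DataArray([1, 2, 3, 4, 5, 6], [('x', [1, 1, 1, 2, 2, 2])])

    # A scalar q reduces each group and drops the 'quantile' coordinate ...
    medians = array.groupby('x').quantile(0.5)      # values [2, 5] along x
    # ... while a sequence of quantiles adds a 'quantile' dimension.
    extremes = array.groupby('x').quantile([0, 1])  # dims ('x', 'quantile')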