Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More consistency checks #2859

Merged
merged 7 commits into from
Jun 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _update_coords(self, coords):

self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._dims = dict(dims)
self._data._dims = dims
self._data._indexes = None

def __delitem__(self, key):
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import warnings
from collections import OrderedDict
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -67,7 +68,7 @@ def _infer_coords_and_dims(shape, coords, dims):
for dim, coord in zip(dims, coords):
var = as_variable(coord, name=dim)
var.dims = (dim,)
new_coords[dim] = var
new_coords[dim] = var.to_index_variable()

sizes = dict(zip(dims, shape))
for k, v in new_coords.items():
Expand Down Expand Up @@ -1442,7 +1443,7 @@ def transpose(self, *dims, transpose_coords=None) -> 'DataArray':

variable = self.variable.transpose(*dims)
if transpose_coords:
coords = {}
coords = OrderedDict() # type: OrderedDict[Any, Variable]
for name, coord in self.coords.items():
coord_dims = tuple(dim for dim in dims if dim in coord.dims)
coords[name] = coord.variable.transpose(*coord_dims)
Expand Down
30 changes: 18 additions & 12 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def calculate_dimensions(variables):
Returns dictionary mapping from dimension names to sizes. Raises ValueError
if any of the dimension sizes conflict.
"""
dims = OrderedDict()
dims = {}
last_used = {}
scalar_vars = set(k for k, v in variables.items() if not v.dims)
for k, var in variables.items():
Expand Down Expand Up @@ -692,7 +692,7 @@ def _construct_direct(cls, variables, coord_names, dims, attrs=None,

@classmethod
def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
dims = dict(calculate_dimensions(variables))
dims = calculate_dimensions(variables)
return cls._construct_direct(variables, coord_names, dims, attrs)

# TODO(shoyer): re-enable type checking on this signature when pytype has a
Expand Down Expand Up @@ -753,18 +753,20 @@ def _replace_with_new_dims( # type: ignore
coord_names: set = None,
attrs: 'Optional[OrderedDict]' = __default,
indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default,
encoding: Optional[dict] = __default,
inplace: bool = False,
) -> T:
"""Replace variables with recalculated dimensions."""
dims = dict(calculate_dimensions(variables))
dims = calculate_dimensions(variables)
return self._replace(
variables, coord_names, dims, attrs, indexes, inplace=inplace)
variables, coord_names, dims, attrs, indexes, encoding,
inplace=inplace)

def _replace_vars_and_dims( # type: ignore
self: T,
variables: 'OrderedDict[Any, Variable]' = None,
coord_names: set = None,
dims: 'OrderedDict[Any, int]' = None,
dims: Dict[Any, int] = None,
attrs: 'Optional[OrderedDict]' = __default,
inplace: bool = False,
) -> T:
Expand Down Expand Up @@ -1080,6 +1082,7 @@ def __delitem__(self, key):
"""
del self._variables[key]
self._coord_names.discard(key)
self._dims = calculate_dimensions(self._variables)

# mutable objects should not be hashable
# https://github.com/python/mypy/issues/4266
Expand Down Expand Up @@ -2469,7 +2472,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):
else:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
variables[k] = v.set_dims(k)
variables[k] = v.set_dims(k).to_index_variable()

new_dims = self._dims.copy()
new_dims.update(dim)
Expand Down Expand Up @@ -3556,12 +3559,15 @@ def from_dict(cls, d):
def _unary_op(f, keep_attrs=False):
@functools.wraps(f)
def func(self, *args, **kwargs):
ds = self.coords.to_dataset()
for k in self.data_vars:
ds._variables[k] = f(self._variables[k], *args, **kwargs)
if keep_attrs:
ds._attrs = self._attrs
return ds
variables = OrderedDict()
for k, v in self._variables.items():
if k in self._coord_names:
variables[k] = v
else:
variables[k] = f(v, *args, **kwargs)
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(
variables, attrs=attrs, encoding=None)

return func

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def merge_core(objs,
'coordinates or not in the merged result: %s'
% ambiguous_coords)

return variables, coord_names, dict(dims)
return variables, coord_names, dims


def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA):
Expand Down
138 changes: 109 additions & 29 deletions xarray/testing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""Testing functions exposed to the user API"""
from collections import OrderedDict
from typing import Hashable, Union

import numpy as np
import pandas as pd

from xarray.core import duck_array_ops, formatting
from xarray.core import duck_array_ops
from xarray.core import formatting
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.variable import IndexVariable, Variable
from xarray.core.indexes import default_indexes


Expand Down Expand Up @@ -48,12 +53,11 @@ def assert_equal(a, b):
assert_identical, assert_allclose, Dataset.equals, DataArray.equals,
numpy.testing.assert_array_equal
"""
import xarray as xr
__tracebackhide__ = True # noqa: F841
assert type(a) == type(b) # noqa
if isinstance(a, (xr.Variable, xr.DataArray)):
if isinstance(a, (Variable, DataArray)):
assert a.equals(b), formatting.diff_array_repr(a, b, 'equals')
elif isinstance(a, xr.Dataset):
elif isinstance(a, Dataset):
assert a.equals(b), formatting.diff_dataset_repr(a, b, 'equals')
else:
raise TypeError('{} not supported by assertion comparison'
Expand All @@ -77,15 +81,14 @@ def assert_identical(a, b):
--------
assert_equal, assert_allclose, Dataset.equals, DataArray.equals
"""
import xarray as xr
__tracebackhide__ = True # noqa: F841
assert type(a) == type(b) # noqa
if isinstance(a, xr.Variable):
if isinstance(a, Variable):
assert a.identical(b), formatting.diff_array_repr(a, b, 'identical')
elif isinstance(a, xr.DataArray):
elif isinstance(a, DataArray):
assert a.name == b.name
assert a.identical(b), formatting.diff_array_repr(a, b, 'identical')
elif isinstance(a, (xr.Dataset, xr.Variable)):
elif isinstance(a, (Dataset, Variable)):
assert a.identical(b), formatting.diff_dataset_repr(a, b, 'identical')
else:
raise TypeError('{} not supported by assertion comparison'
Expand Down Expand Up @@ -117,15 +120,14 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
--------
assert_identical, assert_equal, numpy.testing.assert_allclose
"""
import xarray as xr
__tracebackhide__ = True # noqa: F841
assert type(a) == type(b) # noqa
kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes)
if isinstance(a, xr.Variable):
if isinstance(a, Variable):
assert a.dims == b.dims
allclose = _data_allclose_or_equiv(a.values, b.values, **kwargs)
assert allclose, '{}\n{}'.format(a.values, b.values)
elif isinstance(a, xr.DataArray):
elif isinstance(a, DataArray):
assert_allclose(a.variable, b.variable, **kwargs)
assert set(a.coords) == set(b.coords)
for v in a.coords.variables:
Expand All @@ -135,7 +137,7 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
b.coords[v].values, **kwargs)
assert allclose, '{}\n{}'.format(a.coords[v].values,
b.coords[v].values)
elif isinstance(a, xr.Dataset):
elif isinstance(a, Dataset):
assert set(a.data_vars) == set(b.data_vars)
assert set(a.coords) == set(b.coords)
for k in list(a.variables) + list(a.coords):
Expand All @@ -147,14 +149,12 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):


def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims):
import xarray as xr

assert isinstance(indexes, OrderedDict), indexes
assert all(isinstance(v, pd.Index) for v in indexes.values()), \
{k: type(v) for k, v in indexes.items()}

index_vars = {k for k, v in possible_coord_variables.items()
if isinstance(v, xr.IndexVariable)}
if isinstance(v, IndexVariable)}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)

# Note: when we support non-default indexes, these checks should be opt-in
Expand All @@ -166,17 +166,97 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims):
(indexes, defaults)


def _assert_indexes_invariants(a):
    """Check only the index-related invariants of an xarray object.

    DataArray and Dataset objects are checked when their ``_indexes``
    attribute is populated; every other input (including Variable, which has
    no indexes) is accepted without any check.
    """
    import xarray as xr

    if not isinstance(a, (xr.DataArray, xr.Dataset)):
        # Variables carry no indexes; other types are silently ignored.
        return

    indexes = a._indexes
    if indexes is None:
        # Nothing has been materialised yet, so there is nothing to verify.
        return

    if isinstance(a, xr.DataArray):
        _assert_indexes_invariants_checks(indexes, a._coords, a.dims)
    else:
        _assert_indexes_invariants_checks(indexes, a._variables, a._dims)
def _assert_variable_invariants(var: Variable, name: Hashable = None):
    """Assert that the private state of a single variable is well formed.

    The checks are, in order:

    * ``var._dims`` is a tuple;
    * ``var._dims`` has one entry per axis of ``var._data.shape``;
    * ``var._encoding`` is ``None`` or a ``dict``;
    * ``var._attrs`` is ``None`` or an ``OrderedDict``.

    Parameters
    ----------
    var : Variable
        Object whose private attributes are inspected.
    name : hashable, optional
        When given, it is prepended to the payload of a failing assertion so
        the offending variable can be identified.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    # Failure payloads are tuples; start them with the name when there is one.
    label = () if name is None else (name,)  # type: tuple
    none_type = type(None)

    assert isinstance(var._dims, tuple), label + (var._dims,)
    assert len(var._dims) == len(var._data.shape), (
        label + (var._dims, var._data.shape))
    assert isinstance(var._encoding, (none_type, dict)), (
        label + (var._encoding,))
    assert isinstance(var._attrs, (none_type, OrderedDict)), (
        label + (var._attrs,))


def _assert_dataarray_invariants(da: DataArray):
    """Assert that the private state of a DataArray is internally consistent.

    The checks are, in order:

    * ``da._variable`` is a ``Variable`` that satisfies the per-variable
      invariants;
    * ``da._coords`` is an ``OrderedDict`` whose values are all ``Variable``
      objects;
    * every coordinate only uses dimensions present in ``da.dims``;
    * every coordinate whose dims are exactly ``(its_own_name,)`` is an
      ``IndexVariable``;
    * every coordinate satisfies the per-variable invariants;
    * when ``da._indexes`` is populated, it passes the index checks;
    * ``da._initialized`` is ``True``.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    assert isinstance(da._variable, Variable), da._variable
    _assert_variable_invariants(da._variable)

    assert isinstance(da._coords, OrderedDict), da._coords
    assert all(
        isinstance(v, Variable) for v in da._coords.values()), da._coords
    assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), \
        (da.dims, {k: v.dims for k, v in da._coords.items()})
    # Report only the dimension coordinates on failure, mirroring the filter
    # used by the check itself, so the offending entry is easy to spot.
    assert all(isinstance(v, IndexVariable)
               for (k, v) in da._coords.items()
               if v.dims == (k,)), \
        {k: type(v) for k, v in da._coords.items() if v.dims == (k,)}
    for k, v in da._coords.items():
        _assert_variable_invariants(v, k)

    # Indexes may not have been materialised yet; only validate them if set.
    if da._indexes is not None:
        _assert_indexes_invariants_checks(da._indexes, da._coords, da.dims)

    assert da._initialized is True, da._initialized


def _assert_dataset_invariants(ds: Dataset):
    """Assert that the private state of a Dataset is internally consistent.

    The checks are, in order:

    * ``ds._variables`` is an ``OrderedDict`` of ``Variable`` objects, each
      satisfying the per-variable invariants;
    * ``ds._coord_names`` is a ``set`` and a subset of the variable names;
    * ``ds._dims`` is exactly a ``dict`` mapping to ``int`` sizes, its keys
      are the union of all variables' dims, and its sizes agree with every
      variable's ``sizes``;
    * every variable whose dims are exactly ``(its_own_name,)`` is an
      ``IndexVariable``;
    * every variable named after a dimension has dims ``(its_own_name,)``;
    * when ``ds._indexes`` is populated, it passes the index checks;
    * ``ds._encoding`` is ``None`` or a ``dict``, ``ds._attrs`` is ``None``
      or an ``OrderedDict``, and ``ds._initialized`` is ``True``.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    assert isinstance(ds._variables, OrderedDict), type(ds._variables)
    assert all(
        isinstance(v, Variable) for v in ds._variables.values()), \
        ds._variables
    for k, v in ds._variables.items():
        _assert_variable_invariants(v, k)

    assert isinstance(ds._coord_names, set), ds._coord_names
    assert ds._coord_names <= ds._variables.keys(), \
        (ds._coord_names, set(ds._variables))

    # An exact type test: subclasses of dict (e.g. OrderedDict) are rejected.
    assert type(ds._dims) is dict, ds._dims
    assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims
    var_dims = set()  # type: set
    for v in ds._variables.values():
        var_dims.update(v.dims)
    assert ds._dims.keys() == var_dims, (set(ds._dims), var_dims)
    assert all(ds._dims[k] == v.sizes[k]
               for v in ds._variables.values()
               for k in v.sizes), \
        (ds._dims, {k: v.sizes for k, v in ds._variables.items()})
    assert all(isinstance(v, IndexVariable)
               for (k, v) in ds._variables.items()
               if v.dims == (k,)), \
        {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)}
    assert all(v.dims == (k,)
               for (k, v) in ds._variables.items()
               if k in ds._dims), \
        {k: v.dims for k, v in ds._variables.items() if k in ds._dims}

    # Indexes may not have been materialised yet; only validate them if set.
    if ds._indexes is not None:
        _assert_indexes_invariants_checks(ds._indexes, ds._variables, ds._dims)

    # Attach the offending value to each failure, as the checks above do.
    assert isinstance(ds._encoding, (type(None), dict)), ds._encoding
    assert isinstance(ds._attrs, (type(None), OrderedDict)), ds._attrs
    assert ds._initialized is True, ds._initialized


def _assert_internal_invariants(
    xarray_obj: Union[DataArray, Dataset, Variable],
):
    """Check that an xarray object is consistent with its own invariants.

    Intended for xarray's own test suite. It may also help external projects
    that (ill-advisedly) build objects through xarray's private APIs.

    Raises
    ------
    TypeError
        If ``xarray_obj`` is not a Variable, DataArray or Dataset.
    """
    # Tried in this order; the first matching type decides which checker runs.
    checkers = (
        (Variable, _assert_variable_invariants),
        (DataArray, _assert_dataarray_invariants),
        (Dataset, _assert_dataset_invariants),
    )
    for supported_type, check in checkers:
        if isinstance(xarray_obj, supported_type):
            check(xarray_obj)
            return

    raise TypeError(
        '{} is not a supported type for xarray invariant checks'
        .format(type(xarray_obj)))
13 changes: 6 additions & 7 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,21 +168,20 @@ def source_ndarray(array):

# Internal versions of xarray's test functions that validate additional
# invariants
# TODO: add more invariant checks.

def assert_equal(a, b):
xarray.testing.assert_equal(a, b)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)


def assert_identical(a, b):
xarray.testing.assert_identical(a, b)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)


def assert_allclose(a, b, **kwargs):
xarray.testing.assert_allclose(a, b, **kwargs)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)
5 changes: 5 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2752,6 +2752,11 @@ def test_delitem(self):
assert set(data.variables) == all_items - set(['var1', 'numbers'])
assert 'numbers' not in data.coords

expected = Dataset()
actual = Dataset({'y': ('x', [1, 2])})
del actual['y']
assert_identical(expected, actual)

def test_squeeze(self):
data = Dataset({'foo': (['x', 'y', 'z'], [[[1], [2]]])})
for args in [[], [['x']], [['x', 'z']]]:
Expand Down