Skip to content

Commit

Permalink
Enable additional invariant checks in xarray's test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Mar 31, 2019
1 parent 5308246 commit 32032e8
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 9 deletions.
81 changes: 81 additions & 0 deletions xarray/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,84 @@ def _assert_indexes_invariants(a):
elif isinstance(a, xr.Variable):
# no indexes
pass


def _assert_variable_invariants(a, name=None):
name_or_empty = (name,) if name is not None else ()
assert isinstance(a._dims, tuple), name_or_empty + (a._dims,)
assert len(a._dims) == len(a._data.shape), \
name_or_empty + (a._dims, a._data.shape)
assert isinstance(a._encoding, (type(None), dict)), \
name_or_empty + (a._encoding,)
assert isinstance(a._attrs, (type(None), OrderedDict)), \
name_or_empty + (a._attrs,)


def _assert_dataarray_invariants(a):
import xarray as xr

assert isinstance(a._variable, xr.Variable), a._variable
_assert_variable_invariants(a._variable)

assert isinstance(a._coords, OrderedDict), a._coords
assert all(
isinstance(v, xr.Variable) for v in a._coords.values()), a._coords
assert all(set(v.dims) <= set(a.dims) for v in a._coords.values()), \
(a.dims, {k: v.dims for k, v in a._coords.items()})
assert all(isinstance(v, xr.IndexVariable)
for (k, v) in a._coords.items()
if v.dims == (k,)), \
{k: type(v) for k, v in a._coords.items()}
for k, v in a._coords.items():
_assert_variable_invariants(v, k)

assert a._initialized is True


def _assert_dataset_invariants(a):
import xarray as xr

assert isinstance(a._variables, OrderedDict), type(a._variables)
assert all(
isinstance(v, xr.Variable) for v in a._variables.values()), \
a._variables
for k, v in a._variables.items():
_assert_variable_invariants(v, k)

assert isinstance(a._coord_names, set), a._coord_names
assert a._coord_names <= a._variables.keys(), \
(a._coord_names, set(a._variables))

assert type(a._dims) is dict, a._dims
assert all(isinstance(v, int) for v in a._dims.values()), a._dims
var_dims = set.union(*[set(v.dims) for v in a._variables.values()])
assert a._dims.keys() == var_dims, (set(a._dims), var_dims)
assert all(a._dims[k] == v.sizes[k]
for v in a._variables.values()
for k in v.sizes), \
(a._dims, {k: v.sizes for k, v in a._variables.items()})

assert isinstance(a._encoding, (type(None), dict))
assert isinstance(a._attrs, (type(None), OrderedDict))
assert a._initialized is True


def _assert_internal_invariants(a):
"""Validate that an xarray object satisfies its own internal invariants.
This exists for the benefit of xarray's own test suite, but may be useful
in external projects if they (ill-advisedly) create objects using xarray's
private APIs.
"""
import xarray as xr
if isinstance(a, xr.Variable):
_assert_variable_invariants(a)
elif isinstance(a, xr.DataArray):
_assert_dataarray_invariants(a)
elif isinstance(a, xr.Dataset):
_assert_dataset_invariants(a)
else:
raise TypeError('{} not supported by assertion comparison'
.format(type(a)))

_assert_indexes_invariants(a)
30 changes: 21 additions & 9 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,31 @@ def source_ndarray(array):
# invariants
# TODO: add more invariant checks.

def assert_equal(a, b):
def assert_equal(a, b, *, check_invariants=False):
xarray.testing.assert_equal(a, b)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
if check_invariants:
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)
else:
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)


def assert_identical(a, b):
def assert_identical(a, b, *, check_invariants=False):
xarray.testing.assert_identical(a, b)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
if check_invariants:
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)
else:
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)


def assert_allclose(a, b, **kwargs):
def assert_allclose(a, b, *, check_invariants=False, **kwargs):
xarray.testing.assert_allclose(a, b, **kwargs)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
if check_invariants:
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)
else:
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)

0 comments on commit 32032e8

Please sign in to comment.