Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fixes for DataTree indexing and aggregation #9626

Merged
merged 4 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,8 @@ def reduce(
numeric_only=numeric_only,
**kwargs,
)
result[node.path] = node_result
path = "/" if node is self else node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

def _selective_indexing(
Expand All @@ -1706,15 +1707,17 @@ def _selective_indexing(
# Ideally, we would avoid creating such coordinates in the first
# place, but that would require implementing indexing operations at
# the Variable instead of the Dataset level.
for k in node_indexers:
if k not in node._node_coord_variables and k in node_result.coords:
# We remove all inherited coordinates. Coordinates
# corresponding to an index would be de-duplicated by
# _deduplicate_inherited_coordinates(), but indexing (e.g.,
# with a scalar) can also create scalar coordinates, which
# need to be explicitly removed.
del node_result.coords[k]
result[node.path] = node_result
if node is not self:
for k in node_indexers:
if k not in node._node_coord_variables and k in node_result.coords:
# We remove all inherited coordinates. Coordinates
# corresponding to an index would be de-duplicated by
# _deduplicate_inherited_coordinates(), but indexing (e.g.,
# with a scalar) can also create scalar coordinates, which
# need to be explicitly removed.
del node_result.coords[k]
path = "/" if node is self else node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

def isel(
Expand Down
32 changes: 30 additions & 2 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,18 @@ def test_isel_inherited(self):
assert_equal(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)

# TODO: re-enable after the fix to copy() from #9628 is submitted
# actual = tree.children["child"].isel(x=slice(None))
# expected = tree.children["child"].copy()
# assert_identical(actual, expected)

actual = tree.children["child"].isel(x=0)
expected = DataTree(
dataset=xr.Dataset({"foo": 3}, coords={"x": 1}),
name="child",
)
assert_identical(actual, expected)

def test_sel(self):
tree = DataTree.from_dict(
Expand All @@ -1667,7 +1678,14 @@ def test_sel(self):
}
)
actual = tree.sel(x=2)
assert_equal(actual, expected)
assert_identical(actual, expected)

actual = tree.children["first"].sel(x=2)
expected = DataTree(
dataset=xr.Dataset({"a": 2}, coords={"x": 2}),
name="first",
)
assert_identical(actual, expected)


class TestAggregations:
Expand Down Expand Up @@ -1739,6 +1757,16 @@ def test_dim_argument(self):
):
dt.mean("invalid")

def test_subtree(self):
tree = DataTree.from_dict(
{
"/child": Dataset({"a": ("x", [1, 2])}),
}
)
expected = DataTree(dataset=Dataset({"a": 1.5}), name="child")
actual = tree.children["child"].mean()
assert_identical(expected, actual)


class TestOps:
def test_unary_op(self):
Expand Down
Loading