Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support alternative names for the root node in DataTree.from_dict #9638

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,10 +1104,12 @@ def from_dict(
d : dict-like
A mapping from path names to xarray.Dataset or DataTree objects.

Path names are to be given as unix-like path. If path names containing more than one
part are given, new tree nodes will be constructed as necessary.
Path names are to be given as unix-like path. If path names
containing more than one part are given, new tree nodes will be
constructed as necessary.

To assign data to the root node of the tree use "/" as the path.
To assign data to the root node of the tree use "", ".", "/" or "./"
as the path.
name : Hashable | None, optional
Name for the root node of the tree. Default is None.

Expand All @@ -1119,17 +1121,26 @@ def from_dict(
-----
If your dictionary is nested you will need to flatten it before using this method.
"""

# First create the root node
# Find any values corresponding to the root
d_cast = dict(d)
root_data = d_cast.pop("/", None)
root_data = None
for key in ("", ".", "/", "./"):
if key in d_cast:
if root_data is not None:
raise ValueError(
"multiple entries found corresponding to the root node"
)
root_data = d_cast.pop(key)

# Create the root node
if isinstance(root_data, DataTree):
obj = root_data.copy()
elif root_data is None or isinstance(root_data, Dataset):
obj = cls(name=name, dataset=root_data, children=None)
else:
raise TypeError(
f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}'
f'root node data (at "", ".", "/" or "./") must be a Dataset '
f"or DataTree, got {type(root_data)}"
)

def depth(item) -> int:
Expand All @@ -1141,11 +1152,10 @@ def depth(item) -> int:
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
for path, data in sorted(d_cast.items(), key=depth):
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, dataset=data)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to set name explicitly because it is already guaranteed to be consistent (via NamedNode._post_attach).

new_node = cls(dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
Expand Down Expand Up @@ -1683,7 +1693,7 @@ def reduce(
numeric_only=numeric_only,
**kwargs,
)
path = "/" if node is self else node.relative_to(self)
path = node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

Expand Down Expand Up @@ -1718,7 +1728,7 @@ def _selective_indexing(
# with a scalar) can also create scalar coordinates, which
# need to be explicitly removed.
del node_result.coords[k]
path = "/" if node is self else node.relative_to(self)
path = node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

Expand Down
31 changes: 31 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,37 @@
with pytest.raises(TypeError):
DataTree.from_dict(data) # type: ignore[arg-type]

def test_relative_paths(self) -> None:
    """Relative keys (".", "foo", "./bar", "x/y") are resolved from the root."""
    tree = DataTree.from_dict({".": None, "foo": None, "./bar": None, "x/y": None})
    paths = [node.path for node in tree.subtree]
    # "x/y" has more than one part, so from_dict also constructs the
    # intermediate node "/x". It is part of the subtree and therefore has to
    # appear in the expected list (omitting it made this assertion fail on CI).
    assert paths == [
        "/",
        "/foo",
        "/bar",
        "/x",
        "/x/y",
    ]

def test_root_keys(self) -> None:
    """Every accepted spelling of the root path assigns data to the root node."""
    ds = Dataset({"x": 1})
    expected = DataTree(dataset=ds)

    # "", ".", "/" and "./" are all aliases for the root node and must each
    # produce the same single-node tree.
    for root_key in ("", ".", "/", "./"):
        actual = DataTree.from_dict({root_key: ds})
        assert_identical(actual, expected)

    # Two aliases in the same mapping are ambiguous and must be rejected.
    with pytest.raises(
        ValueError, match="multiple entries found corresponding to the root node"
    ):
        DataTree.from_dict({"": ds, "/": ds})


class TestDatasetView:
def test_view_contents(self) -> None:
Expand Down
Loading