Skip to content

Commit

Permalink
Update DataTree repr to indicate inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Sep 9, 2024
1 parent cea354f commit 67cc75f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 32 deletions.
118 changes: 94 additions & 24 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import contextlib
import functools
import math
from collections import defaultdict
from collections import ChainMap, defaultdict
from collections.abc import Collection, Hashable, Sequence
from datetime import datetime, timedelta
from itertools import chain, zip_longest
Expand All @@ -29,6 +29,7 @@
if TYPE_CHECKING:
from xarray.core.coordinates import AbstractCoordinates
from xarray.core.datatree import DataTree
from xarray.core.variable import Variable

UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

Expand Down Expand Up @@ -318,7 +319,7 @@ def inline_variable_array_repr(var, max_width):

def summarize_variable(
name: Hashable,
var,
var: Variable,
col_width: int,
max_width: int | None = None,
is_index: bool = False,
Expand Down Expand Up @@ -446,6 +447,21 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None):
)


def inherited_coords_repr(node: DataTree, col_width=None, max_rows=None):
    """Render the "Inherited coordinates" section for ``node``.

    The section lists the entries of ``node._coord_variables`` selected by
    ``_inherited_vars``. When ``col_width`` is not given it is computed from
    those entries alone.
    """
    inherited = _inherited_vars(node._coord_variables)
    width = _calculate_col_width(inherited) if col_width is None else col_width
    return _mapping_repr(
        inherited,
        title="Inherited coordinates",
        summarizer=summarize_variable,
        expand_option_name="display_expand_coords",
        col_width=width,
        indexes=node._indexes,
        max_rows=max_rows,
    )


def inline_index_repr(index: pd.Index, max_width=None):
if hasattr(index, "_repr_inline_"):
repr_ = index._repr_inline_(max_width=max_width)
Expand Down Expand Up @@ -498,12 +514,12 @@ def filter_nondefault_indexes(indexes, filter_indexes: bool):
}


def indexes_repr(indexes, max_rows: int | None = None) -> str:
def indexes_repr(indexes, max_rows: int | None = None, title: str = "Indexes") -> str:
col_width = _calculate_col_width(chain.from_iterable(indexes))

return _mapping_repr(
indexes,
"Indexes",
title,
summarize_index,
"display_expand_indexes",
col_width=col_width,
Expand Down Expand Up @@ -1048,39 +1064,93 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
return "\n".join(summary)


def _single_node_repr(node: DataTree) -> str:
    """Describe one node on its own, ignoring its relationships to other nodes.

    A node with neither data nor attributes is rendered as just its
    ``Group: <path>`` header; otherwise the repr of its dataset view follows
    the header on a new line.
    """
    if not (node.has_data or node.has_attrs):
        return f"Group: {node.path}"

    # TODO: change this to inherited=False, in order to clarify what is
    # inherited? https://github.com/pydata/xarray/issues/9463
    dataset_view = node._to_dataset_view(rebuild_dims=False, inherited=True)
    return f"Group: {node.path}\n{dataset_view!r}"
def _inherited_vars(mapping: ChainMap) -> dict:
return {k: v for k, v in mapping.parents.items() if k not in mapping.maps[0]}


def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:
    """Build the multi-line text description of a single tree node.

    The first line is always ``Group: <path>``. Further sections are appended
    in a fixed order, each only when it has content: dimensions, coordinates,
    inherited coordinates (only if ``show_inherited`` is true), dimensions
    without coordinates, data variables, indexes and attributes.
    """
    sections = [f"Group: {node.path}"]

    col_width = _calculate_col_width(node.variables)
    max_rows = OPTIONS["display_max_rows"]

    inherited_coords = _inherited_vars(node._coord_variables)

    # Only show dimensions if also showing a variable or coordinates section.
    show_dims = (
        node._node_coord_variables
        or (show_inherited and inherited_coords)
        or node._data_variables
    )

    if show_dims:
        # Includes inherited dimensions.
        dims_label = pretty_print("Dimensions:", col_width)
        dims_text = dim_summary_limited(
            node, col_width=col_width + 1, max_rows=max_rows
        )
        sections.append(f"{dims_label}({dims_text})")

    if node._node_coord_variables:
        coords_section = coords_repr(
            node.coords, col_width=col_width, max_rows=max_rows
        )
        sections.append(coords_section)

    if show_inherited and inherited_coords:
        inherited_section = inherited_coords_repr(
            node, col_width=col_width, max_rows=max_rows
        )
        sections.append(inherited_section)

    if show_dims:
        unindexed_section = unindexed_dims_repr(
            node.dims, node.coords, max_rows=max_rows
        )
        # Skipped when the helper returns an empty/falsy result.
        if unindexed_section:
            sections.append(unindexed_section)

    if node._data_variables:
        data_section = data_vars_repr(
            node._data_variables, col_width=col_width, max_rows=max_rows
        )
        sections.append(data_section)

    # TODO: only show indexes defined at this node, with a separate section for
    # inherited indexes (if show_inherited=True)
    display_default_indexes = _get_boolean_with_default(
        "display_default_indexes", False
    )
    all_indexes = _get_indexes_dict(node.xindexes)
    xindexes = filter_nondefault_indexes(all_indexes, not display_default_indexes)
    if xindexes:
        sections.append(indexes_repr(xindexes, max_rows=max_rows))

    if node.attrs:
        sections.append(attrs_repr(node.attrs, max_rows=max_rows))

    return "\n".join(sections)


def datatree_repr(dt: DataTree) -> str:
    """A printable representation of the structure of this entire tree.

    The output starts with an ``<xarray.DataTree ...>`` header, followed by
    one block per node yielded by ``RenderDataTree``. Inherited coordinates
    are requested only for the first node rendered (the root of ``dt``).
    """
    renderer = RenderDataTree(dt)

    name_info = "" if dt.name is None else f" {dt.name!r}"
    lines = [f"<xarray.DataTree{name_info}>"]

    for position, (pre, fill, node) in enumerate(renderer):
        # only show inherited coords on the root
        node_repr = _datatree_node_repr(node, show_inherited=(position == 0))
        raw_repr_lines = node_repr.splitlines()

        # The "Group: ..." line carries the tree-drawing prefix.
        lines.append(f"{pre}{raw_repr_lines[0]}")

        # Detail lines get a vertical connector when children follow below,
        # otherwise blank padding of the same width.
        connector = renderer.style.vertical
        if len(node.children) > 0:
            gutter = connector
        else:
            gutter = " " * len(connector)
        lines.extend(f"{fill}{gutter}{line}" for line in raw_repr_lines[1:])

    return "\n".join(lines)

Expand Down
10 changes: 2 additions & 8 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,20 +798,16 @@ def test_repr(self):
│ e (x) float64 16B 1.0 2.0
└── Group: /b
│ Dimensions: (x: 2, y: 1)
│ Coordinates:
│ * x (x) float64 16B 2.0 3.0
│ Dimensions without coordinates: y
│ Data variables:
│ f (y) float64 8B 3.0
├── Group: /b/c
└── Group: /b/d
Dimensions: (x: 2, y: 1)
Coordinates:
* x (x) float64 16B 2.0 3.0
Dimensions without coordinates: y
Data variables:
g float64 8B 4.0
"""
"""
).strip()
assert result == expected

Expand All @@ -821,16 +817,14 @@ def test_repr(self):
<xarray.DataTree 'b'>
Group: /b
│ Dimensions: (x: 2, y: 1)
Coordinates:
Inherited coordinates:
│ * x (x) float64 16B 2.0 3.0
│ Dimensions without coordinates: y
│ Data variables:
│ f (y) float64 8B 3.0
├── Group: /b/c
└── Group: /b/d
Dimensions: (x: 2, y: 1)
Coordinates:
* x (x) float64 16B 2.0 3.0
Dimensions without coordinates: y
Data variables:
g float64 8B 4.0
Expand Down

0 comments on commit 67cc75f

Please sign in to comment.