diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 657c9a2dbfb..e4b9d928d5b 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -6,7 +6,7 @@ import contextlib import functools import math -from collections import defaultdict +from collections import ChainMap, defaultdict from collections.abc import Collection, Hashable, Sequence from datetime import datetime, timedelta from itertools import chain, zip_longest @@ -29,6 +29,7 @@ if TYPE_CHECKING: from xarray.core.coordinates import AbstractCoordinates from xarray.core.datatree import DataTree + from xarray.core.variable import Variable UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") @@ -318,7 +319,7 @@ def inline_variable_array_repr(var, max_width): def summarize_variable( name: Hashable, - var, + var: Variable, col_width: int, max_width: int | None = None, is_index: bool = False, @@ -446,6 +447,21 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None): ) +def inherited_coords_repr(node: DataTree, col_width=None, max_rows=None): + coords = _inherited_vars(node._coord_variables) + if col_width is None: + col_width = _calculate_col_width(coords) + return _mapping_repr( + coords, + title="Inherited coordinates", + summarizer=summarize_variable, + expand_option_name="display_expand_coords", + col_width=col_width, + indexes=node._indexes, + max_rows=max_rows, + ) + + def inline_index_repr(index: pd.Index, max_width=None): if hasattr(index, "_repr_inline_"): repr_ = index._repr_inline_(max_width=max_width) @@ -498,12 +514,12 @@ def filter_nondefault_indexes(indexes, filter_indexes: bool): } -def indexes_repr(indexes, max_rows: int | None = None) -> str: +def indexes_repr(indexes, max_rows: int | None = None, title: str = "Indexes") -> str: col_width = _calculate_col_width(chain.from_iterable(indexes)) return _mapping_repr( indexes, - "Indexes", + title, summarize_index, "display_expand_indexes", col_width=col_width, @@ -1048,19 +1064,71 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): return "\n".join(summary) -def _single_node_repr(node: DataTree) -> str: - """Information about this node, not including its relationships to other nodes.""" - if node.has_data or node.has_attrs: - # TODO: change this to inherited=False, in order to clarify what is - # inherited? https://github.com/pydata/xarray/issues/9463 - node_view = node._to_dataset_view(rebuild_dims=False, inherited=True) - ds_info = "\n" + repr(node_view) - else: - ds_info = "" - return f"Group: {node.path}{ds_info}" +def _inherited_vars(mapping: ChainMap) -> dict: + return {k: v for k, v in mapping.parents.items() if k not in mapping.maps[0]} + + +def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: + summary = [f"Group: {node.path}"] + + col_width = _calculate_col_width(node.variables) + max_rows = OPTIONS["display_max_rows"] + + inherited_coords = _inherited_vars(node._coord_variables) + # Only show dimensions if also showing a variable or coordinates section. + show_dims = ( + node._node_coord_variables + or (show_inherited and inherited_coords) + or node._data_variables + ) + + if show_dims: + # Includes inherited dimensions. + dims_start = pretty_print("Dimensions:", col_width) + dims_values = dim_summary_limited( + node, col_width=col_width + 1, max_rows=max_rows + ) + summary.append(f"{dims_start}({dims_values})") -def datatree_repr(dt: DataTree): + if node._node_coord_variables: + summary.append(coords_repr(node.coords, col_width=col_width, max_rows=max_rows)) + + if show_inherited and inherited_coords: + summary.append( + inherited_coords_repr(node, col_width=col_width, max_rows=max_rows) + ) + + if show_dims: + unindexed_dims_str = unindexed_dims_repr( + node.dims, node.coords, max_rows=max_rows + ) + if unindexed_dims_str: + summary.append(unindexed_dims_str) + + if node._data_variables: + summary.append( + data_vars_repr(node._data_variables, col_width=col_width, max_rows=max_rows) + ) + + # TODO: only show indexes defined at this node, with a separate section for + # inherited indexes (if show_inherited=True) + display_default_indexes = _get_boolean_with_default( + "display_default_indexes", False + ) + xindexes = filter_nondefault_indexes( + _get_indexes_dict(node.xindexes), not display_default_indexes + ) + if xindexes: + summary.append(indexes_repr(xindexes, max_rows=max_rows)) + + if node.attrs: + summary.append(attrs_repr(node.attrs, max_rows=max_rows)) + + return "\n".join(summary) + + +def datatree_repr(dt: DataTree) -> str: """A printable representation of the structure of this entire tree.""" renderer = RenderDataTree(dt) @@ -1068,19 +1136,21 @@ def datatree_repr(dt: DataTree): header = f"" lines = [header] + show_inherited = True for pre, fill, node in renderer: - node_repr = _single_node_repr(node) + node_repr = _datatree_node_repr(node, show_inherited=show_inherited) + show_inherited = False # only show inherited coords on the root - node_line = f"{pre}{node_repr.splitlines()[0]}" + raw_repr_lines = node_repr.splitlines() + + node_line = f"{pre}{raw_repr_lines[0]}" lines.append(node_line) - if node.has_data or node.has_attrs: - ds_repr = node_repr.splitlines()[2:] - for line in ds_repr: - if len(node.children) > 0: - lines.append(f"{fill}{renderer.style.vertical}{line}") - else: - lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") + for line in raw_repr_lines[1:]: + if len(node.children) > 0: + lines.append(f"{fill}{renderer.style.vertical}{line}") + else: + lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") return "\n".join(lines) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index f1f74d240f0..cbdbd541fb0 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -798,20 +798,16 @@ def test_repr(self): │ e (x) float64 16B 1.0 2.0 └── Group: /b │ Dimensions: (x: 2, y: 1) - │ Coordinates: - │ * x (x) float64 16B 2.0 3.0 │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d Dimensions: (x: 2, y: 1) - Coordinates: - * x (x) float64 16B 2.0 3.0 Dimensions without coordinates: y Data variables: g float64 8B 4.0 - """ + """ ).strip() assert result == expected @@ -821,7 +817,7 @@ def test_repr(self): Group: /b │ Dimensions: (x: 2, y: 1) - │ Coordinates: + │ Inherited coordinates: │ * x (x) float64 16B 2.0 3.0 │ Dimensions without coordinates: y │ Data variables: @@ -829,8 +825,6 @@ def test_repr(self): ├── Group: /b/c └── Group: /b/d Dimensions: (x: 2, y: 1) - Coordinates: - * x (x) float64 16B 2.0 3.0 Dimensions without coordinates: y Data variables: g float64 8B 4.0