From 70e661dc3d3ff952bc4a1f6453a7e4f7067dbf12 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 18:34:40 -0700 Subject: [PATCH] fix repr for inherited dimensions --- xarray/core/formatting.py | 26 ++++++++----- xarray/tests/test_datatree.py | 67 +++++++++++++++++++++++++++++---- xarray/tests/test_formatting.py | 2 +- 3 files changed, 78 insertions(+), 17 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index e4b9d928d5b..3f42d4828a3 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -7,12 +7,12 @@ import functools import math from collections import ChainMap, defaultdict -from collections.abc import Collection, Hashable, Sequence +from collections.abc import Collection, Hashable, Mapping, Sequence from datetime import datetime, timedelta from itertools import chain, zip_longest from reprlib import recursive_repr from textwrap import dedent -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -587,8 +587,10 @@ def _element_formatter( return "".join(out) -def dim_summary_limited(obj, col_width: int, max_rows: int | None = None) -> str: - elements = [f"{k}: {v}" for k, v in obj.sizes.items()] +def dim_summary_limited( + sizes: Mapping[Any, int], col_width: int, max_rows: int | None = None +) -> str: + elements = [f"{k}: {v}" for k, v in sizes.items()] return _element_formatter(elements, col_width, max_rows) @@ -692,7 +694,7 @@ def array_repr(arr): data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"]) start = f"<xarray.{type(arr).__name__} {name_str}" - dims = dim_summary_limited(arr, col_width=len(start) + 1, max_rows=max_rows) + dims = dim_summary_limited(arr.sizes, col_width=len(start) + 1, max_rows=max_rows) nbytes_str = render_human_readable_nbytes(arr.nbytes) summary = [ f"{start}({dims})> Size: {nbytes_str}", @@ -737,7 +739,9 @@ def dataset_repr(ds): max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) - dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + dims_values = dim_summary_limited( + ds.sizes, col_width=col_width + 1, max_rows=max_rows + ) summary.append(f"{dims_start}({dims_values})") if ds.coords: @@ -772,7
+776,9 @@ def dims_and_coords_repr(ds) -> str: max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) - dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + dims_values = dim_summary_limited( + ds.sizes, col_width=col_width + 1, max_rows=max_rows + ) summary.append(f"{dims_start}({dims_values})") if ds.coords: @@ -1083,11 +1089,13 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: or node._data_variables ) + dim_sizes = node.sizes if show_inherited else node._node_dims + if show_dims: # Includes inherited dimensions. dims_start = pretty_print("Dimensions:", col_width) dims_values = dim_summary_limited( - node, col_width=col_width + 1, max_rows=max_rows + dim_sizes, col_width=col_width + 1, max_rows=max_rows ) summary.append(f"{dims_start}({dims_values})") @@ -1101,7 +1109,7 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: if show_dims: unindexed_dims_str = unindexed_dims_repr( - node.dims, node.coords, max_rows=max_rows + dim_sizes, node.coords, max_rows=max_rows ) if unindexed_dims_str: summary.append(unindexed_dims_str) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index ba3041f271f..ac074b90d62 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -773,7 +773,8 @@ def test_operation_with_attrs_but_no_data(self): class TestRepr: - def test_repr(self): + + def test_repr_four_nodes(self): dt = DataTree.from_dict( { "/": xr.Dataset( @@ -797,14 +798,13 @@ def test_repr(self): │ Data variables: │ e (x) float64 16B 1.0 2.0 └── Group: /b - │ Dimensions: (x: 2, y: 1) + │ Dimensions: (y: 1) │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d - Dimensions: (x: 2, y: 1) - Dimensions without coordinates: y + Dimensions: () Data variables: g float64 8B 4.0 """ @@ -824,15 +824,29 @@ def test_repr(self): │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: 
/b/d - Dimensions: (x: 2, y: 1) - Dimensions without coordinates: y + Dimensions: () Data variables: g float64 8B 4.0 """ ).strip() assert result == expected - def test_repr2(self): + result = repr(dt.b.d) + expected = dedent( + """ + + Group: /b/d + Dimensions: (x: 2, y: 1) + Inherited coordinates: + * x (x) float64 16B 2.0 3.0 + Dimensions without coordinates: y + Data variables: + g float64 8B 4.0 + """ + ).strip() + assert result == expected + + def test_repr_two_children(self): tree = DataTree.from_dict( { "/": Dataset(coords={"x": [1.0]}), @@ -884,6 +898,45 @@ def test_repr2(self): ).strip() assert result == expected + def test_repr_inherited_dims(self): + tree = DataTree.from_dict( + { + "/": Dataset({"foo": ("x", [1.0])}), + "/child": Dataset({"bar": ("y", [2.0])}), + } + ) + + result = repr(tree) + expected = dedent( + """ + + Group: / + │ Dimensions: (x: 1) + │ Dimensions without coordinates: x + │ Data variables: + │ foo (x) float64 8B 1.0 + └── Group: /child + Dimensions: (y: 1) + Dimensions without coordinates: y + Data variables: + bar (y) float64 8B 2.0 + """ + ).strip() + assert result == expected + + result = repr(tree["child"]) + expected = dedent( + """ + + Group: /child + Dimensions: (x: 1, y: 1) + Dimensions without coordinates: x, y + Data variables: + bar (y) float64 8B 2.0 + """ + ).strip() + assert result == expected + def _exact_match(message: str) -> str: return re.escape(dedent(message).strip()) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 696c849cea1..039bbfb4606 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -883,7 +883,7 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: col_width = formatting._calculate_col_width(ds.variables) dims_start = formatting.pretty_print("Dimensions:", col_width) dims_values = formatting.dim_summary_limited( - ds, col_width=col_width + 1, max_rows=display_max_rows + ds.sizes, col_width=col_width + 1, 
max_rows=display_max_rows ) expected_size = "1kB" expected = f"""\