Skip to content

Commit

Permalink
fix repr for inherited dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Sep 10, 2024
1 parent 3b42219 commit 70e661d
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 17 deletions.
26 changes: 17 additions & 9 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import functools
import math
from collections import ChainMap, defaultdict
from collections.abc import Collection, Hashable, Sequence
from collections.abc import Collection, Hashable, Mapping, Sequence
from datetime import datetime, timedelta
from itertools import chain, zip_longest
from reprlib import recursive_repr
from textwrap import dedent
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -587,8 +587,10 @@ def _element_formatter(
return "".join(out)


def dim_summary_limited(obj, col_width: int, max_rows: int | None = None) -> str:
elements = [f"{k}: {v}" for k, v in obj.sizes.items()]
def dim_summary_limited(
sizes: Mapping[Any, int], col_width: int, max_rows: int | None = None
) -> str:
elements = [f"{k}: {v}" for k, v in sizes.items()]
return _element_formatter(elements, col_width, max_rows)


Expand Down Expand Up @@ -692,7 +694,7 @@ def array_repr(arr):
data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"])

start = f"<xarray.{type(arr).__name__} {name_str}"
dims = dim_summary_limited(arr, col_width=len(start) + 1, max_rows=max_rows)
dims = dim_summary_limited(arr.sizes, col_width=len(start) + 1, max_rows=max_rows)
nbytes_str = render_human_readable_nbytes(arr.nbytes)
summary = [
f"{start}({dims})> Size: {nbytes_str}",
Expand Down Expand Up @@ -737,7 +739,9 @@ def dataset_repr(ds):
max_rows = OPTIONS["display_max_rows"]

dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
dims_values = dim_summary_limited(
ds.sizes, col_width=col_width + 1, max_rows=max_rows
)
summary.append(f"{dims_start}({dims_values})")

if ds.coords:
Expand Down Expand Up @@ -772,7 +776,9 @@ def dims_and_coords_repr(ds) -> str:
max_rows = OPTIONS["display_max_rows"]

dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
dims_values = dim_summary_limited(
ds.sizes, col_width=col_width + 1, max_rows=max_rows
)
summary.append(f"{dims_start}({dims_values})")

if ds.coords:
Expand Down Expand Up @@ -1083,11 +1089,13 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:
or node._data_variables
)

dim_sizes = node.sizes if show_inherited else node._node_dims

if show_dims:
# Includes inherited dimensions.
dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(
node, col_width=col_width + 1, max_rows=max_rows
dim_sizes, col_width=col_width + 1, max_rows=max_rows
)
summary.append(f"{dims_start}({dims_values})")

Expand All @@ -1101,7 +1109,7 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:

if show_dims:
unindexed_dims_str = unindexed_dims_repr(
node.dims, node.coords, max_rows=max_rows
dim_sizes, node.coords, max_rows=max_rows
)
if unindexed_dims_str:
summary.append(unindexed_dims_str)
Expand Down
67 changes: 60 additions & 7 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,8 @@ def test_operation_with_attrs_but_no_data(self):


class TestRepr:
def test_repr(self):

def test_repr_four_nodes(self):
dt = DataTree.from_dict(
{
"/": xr.Dataset(
Expand All @@ -797,14 +798,13 @@ def test_repr(self):
│ Data variables:
│ e (x) float64 16B 1.0 2.0
└── Group: /b
│ Dimensions: (x: 2, y: 1)
│ Dimensions: (y: 1)
│ Dimensions without coordinates: y
│ Data variables:
│ f (y) float64 8B 3.0
├── Group: /b/c
└── Group: /b/d
Dimensions: (x: 2, y: 1)
Dimensions without coordinates: y
Dimensions: ()
Data variables:
g float64 8B 4.0
"""
Expand All @@ -824,15 +824,29 @@ def test_repr(self):
│ f (y) float64 8B 3.0
├── Group: /b/c
└── Group: /b/d
Dimensions: (x: 2, y: 1)
Dimensions without coordinates: y
Dimensions: ()
Data variables:
g float64 8B 4.0
"""
).strip()
assert result == expected

def test_repr2(self):
result = repr(dt.b.d)
expected = dedent(
"""
<xarray.DataTree 'd'>
Group: /b/d
Dimensions: (x: 2, y: 1)
Inherited coordinates:
* x (x) float64 16B 2.0 3.0
Dimensions without coordinates: y
Data variables:
g float64 8B 4.0
"""
).strip()
assert result == expected

def test_repr_two_children(self):
tree = DataTree.from_dict(
{
"/": Dataset(coords={"x": [1.0]}),
Expand Down Expand Up @@ -884,6 +898,45 @@ def test_repr2(self):
).strip()
assert result == expected

def test_repr_inherited_dims(self):
tree = DataTree.from_dict(
{
"/": Dataset({"foo": ("x", [1.0])}),
"/child": Dataset({"bar": ("y", [2.0])}),
}
)

result = repr(tree)
expected = dedent(
"""
<xarray.DataTree>
Group: /
│ Dimensions: (x: 1)
│ Dimensions without coordinates: x
│ Data variables:
│ foo (x) float64 8B 1.0
└── Group: /child
Dimensions: (y: 1)
Dimensions without coordinates: y
Data variables:
bar (y) float64 8B 2.0
"""
).strip()
assert result == expected

result = repr(tree["child"])
expected = dedent(
"""
<xarray.DataTree 'child'>
Group: /child
Dimensions: (x: 1, y: 1)
Dimensions without coordinates: x, y
Data variables:
bar (y) float64 8B 2.0
"""
).strip()
assert result == expected


def _exact_match(message: str) -> str:
return re.escape(dedent(message).strip())
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None:
col_width = formatting._calculate_col_width(ds.variables)
dims_start = formatting.pretty_print("Dimensions:", col_width)
dims_values = formatting.dim_summary_limited(
ds, col_width=col_width + 1, max_rows=display_max_rows
ds.sizes, col_width=col_width + 1, max_rows=display_max_rows
)
expected_size = "1kB"
expected = f"""\
Expand Down

0 comments on commit 70e661d

Please sign in to comment.