diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c969453b108..f1137b7b2a2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -68,6 +68,8 @@ Bug fixes By `Alessandro Amici `_ - Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling `_. - Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo `_. +- Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`). + By `Justus Magin `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 8ed8815a060..58b38251352 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -306,9 +306,11 @@ def map_dataarray_line( ) self._mappables.append(mappable) - _, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data( + xplt, yplt, hueplt, huelabel = _infer_line_data( darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue ) + xlabel = label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) self._hue_var = hueplt self._hue_label = huelabel diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 2f10240e1b7..8a57e17e5e8 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -107,10 +107,7 @@ def _infer_line_data(darray, x, y, hue): huelabel = label_from_attrs(darray[huename]) hueplt = darray[huename] - xlabel = label_from_attrs(xplt) - ylabel = label_from_attrs(yplt) - - return xplt, yplt, hueplt, xlabel, ylabel, huelabel + return xplt, yplt, hueplt, huelabel def plot( @@ -292,12 +289,14 @@ def line( assert "args" not in kwargs ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, xlabel, ylabel, hue_label = _infer_line_data(darray, x, y, hue) + xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) # Remove pd.Intervals if contained in xplt.values and/or yplt.values. - xplt_val, yplt_val, xlabel, ylabel, kwargs = _resolve_intervals_1dplot( - xplt.values, yplt.values, xlabel, ylabel, kwargs + xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( + xplt.values, yplt.values, kwargs ) + xlabel = label_from_attrs(xplt, extra=x_suffix) + ylabel = label_from_attrs(yplt, extra=y_suffix) _ensure_plottable(xplt_val, yplt_val) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 3eca90a1dfe..16c67e154fc 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -503,12 +503,14 @@ def _interval_to_double_bound_points(xarray, yarray): return xarray, yarray -def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs): +def _resolve_intervals_1dplot(xval, yval, kwargs): """ Helper function to replace the values of x and/or y coordinate arrays containing pd.Interval with their mid-points or - for step plots - double points which double the length. """ + x_suffix = "" + y_suffix = "" # Is it a step plot? (see matplotlib.Axes.step) if kwargs.get("drawstyle", "").startswith("steps-"): @@ -534,13 +536,13 @@ def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs): # Convert intervals to mid points and adjust labels if _valid_other_type(xval, [pd.Interval]): xval = _interval_to_mid_points(xval) - xlabel += "_center" + x_suffix = "_center" if _valid_other_type(yval, [pd.Interval]): yval = _interval_to_mid_points(yval) - ylabel += "_center" + y_suffix = "_center" # return converted arguments - return xval, yval, xlabel, ylabel, kwargs + return xval, yval, x_suffix, y_suffix, kwargs def _resolve_intervals_2dplot(val, func_name): diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 2f4a4edd436..471bbb7051e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -592,6 +592,20 @@ def test_coord_with_interval_xy(self): bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot() + @pytest.mark.parametrize("dim", ("x", "y")) + def test_labels_with_units_with_interval(self, dim): + """Test line plot with intervals and a units attribute.""" + bins = [-1, 0, 1, 2] + arr = self.darray.groupby_bins("dim_0", bins).mean(...) + arr.dim_0_bins.attrs["units"] = "m" + + (mappable,) = arr.plot(**{dim: "dim_0_bins"}) + ax = mappable.figure.gca() + actual = getattr(ax, f"get_{dim}label")() + + expected = "dim_0_bins_center [m]" + assert actual == expected + class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True)