From fd2c467fc15ba5f906660598402513f557010e92 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 29 Jul 2023 22:30:48 +0200 Subject: [PATCH 01/20] fix several typing issues with mpl3.8 --- xarray/core/options.py | 4 +- xarray/plot/dataarray_plot.py | 12 +++- xarray/tests/test_plot.py | 104 ++++++++++++++++++++++------------ 3 files changed, 81 insertions(+), 39 deletions(-) diff --git a/xarray/core/options.py b/xarray/core/options.py index eb0c56c7ee0..828d56cd300 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -164,11 +164,11 @@ class set_options: cmap_divergent : str or matplotlib.colors.Colormap, default: "RdBu_r" Colormap to use for divergent data plots. If string, must be matplotlib built-in colormap. Can also be a Colormap object - (e.g. mpl.cm.magma) + (e.g. mpl.colormaps["magma"]) cmap_sequential : str or matplotlib.colors.Colormap, default: "viridis" Colormap to use for nondivergent data plots. If string, must be matplotlib built-in colormap. Can also be a Colormap object - (e.g. mpl.cm.magma) + (e.g. mpl.colormaps["magma"]) display_expand_attrs : {"default", True, False} Whether to expand the attributes section for display of ``DataArray`` or ``Dataset`` objects. Can be diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index d2c0a8e2af6..4318b3db7b2 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -992,9 +992,13 @@ def newplotfunc( with plt.rc_context(_styles): if z is not None: + import mpl_toolkits + if ax is None: subplot_kws.update(projection="3d") ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + # Using 30, 30 minimizes rotation of the plot. Making it easier to # build on your intuition from 2D plots: ax.view_init(azim=30, elev=30, vertical_axis="y") @@ -1261,8 +1265,8 @@ def scatter( plts_dict: dict[str, DataArray | None] = dict(x=xplt, y=yplt, z=zplt) plts_or_none = [plts_dict[v] for v in axis_order] - plts = [p for p in plts_or_none if p is not None] - primitive = ax.scatter(*[p.to_numpy().ravel() for p in plts], **kwargs) + plts = [p.to_numpy().ravel() for p in plts_or_none if p is not None] + primitive = ax.scatter(*plts, **kwargs) _add_labels(add_labels, plts, ("", "", ""), (True, False, False), ax) return primitive @@ -1616,6 +1620,7 @@ def newplotfunc( ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) if plotfunc.__name__ == "surface": + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) ax.set_zlabel(label_from_attrs(darray)) if add_colorbar: @@ -2465,5 +2470,8 @@ def surface( Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. """ + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) primitive = ax.plot_surface(x, y, z, **kwargs) return primitive diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 8b2dfbdec41..fc80c60e00f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -43,6 +43,7 @@ # import mpl and change the backend before other mpl imports try: import matplotlib as mpl + import matplotlib.dates import matplotlib.pyplot as plt import mpl_toolkits except ImportError: @@ -421,6 +422,7 @@ def test2d_1d_2d_coordinates_pcolormesh(self) -> None: ]: p = a.plot.pcolormesh(x=x, y=y) v = p.get_paths()[0].vertices + assert isinstance(v, np.ndarray) # Check all vertices are different, except last vertex which should be the # same as the first @@ -440,7 +442,7 @@ def test_str_coordinates_pcolormesh(self) -> None: def test_contourf_cmap_set(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) - cmap = mpl.cm.viridis + cmap_expected = mpl.colormaps["viridis"] # use copy to ensure cmap is not changed by contourf() # Set vmin and vmax so that _build_discrete_colormap is called with @@ -450,55 +452,59 @@ def test_contourf_cmap_set(self) -> None: # extend='neither' (but if extend='neither' the under and over values # would not be used because the data would all be within the plotted # range) - pl = a.plot.contourf(cmap=copy(cmap), vmin=0.1, vmax=0.9) + pl = a.plot.contourf(cmap=copy(cmap_expected), vmin=0.1, vmax=0.9) # check the set_bad color + cmap = pl.cmap + assert cmap is not None assert_array_equal( - pl.cmap(np.ma.masked_invalid([np.nan]))[0], cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], ) # check the set_under color - assert pl.cmap(-np.inf) == cmap(-np.inf) + assert cmap(-np.inf) == cmap_expected(-np.inf) # check the set_over color - assert pl.cmap(np.inf) == cmap(np.inf) + assert cmap(np.inf) == cmap_expected(np.inf) def test_contourf_cmap_set_with_bad_under_over(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) # make a copy here because we want a local cmap that we will modify. - cmap = copy(mpl.cm.viridis) + cmap_expected = copy(mpl.colormaps["viridis"]) - cmap.set_bad("w") + cmap_expected.set_bad("w") # check we actually changed the set_bad color assert np.all( - cmap(np.ma.masked_invalid([np.nan]))[0] - != mpl.cm.viridis(np.ma.masked_invalid([np.nan]))[0] + cmap_expected(np.ma.masked_invalid([np.nan]))[0] + != mpl.colormaps["viridis"](np.ma.masked_invalid([np.nan]))[0] ) - cmap.set_under("r") + cmap_expected.set_under("r") # check we actually changed the set_under color - assert cmap(-np.inf) != mpl.cm.viridis(-np.inf) + assert cmap_expected(-np.inf) != mpl.colormaps["viridis"](-np.inf) - cmap.set_over("g") + cmap_expected.set_over("g") # check we actually changed the set_over color - assert cmap(np.inf) != mpl.cm.viridis(-np.inf) + assert cmap_expected(np.inf) != mpl.colormaps["viridis"](-np.inf) # copy to ensure cmap is not changed by contourf() - pl = a.plot.contourf(cmap=copy(cmap)) + pl = a.plot.contourf(cmap=copy(cmap_expected)) + cmap = pl.cmap + assert cmap is not None # check the set_bad color has been kept assert_array_equal( - pl.cmap(np.ma.masked_invalid([np.nan]))[0], cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], ) # check the set_under color has been kept - assert pl.cmap(-np.inf) == cmap(-np.inf) + assert cmap(-np.inf) == cmap_expected(-np.inf) # check the set_over color has been kept - assert pl.cmap(np.inf) == cmap(np.inf) + assert cmap(np.inf) == cmap_expected(np.inf) def test3d(self) -> None: self.darray.plot() @@ -831,19 +837,25 @@ def test_coord_with_interval_step(self) -> None: """Test step plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_x(self) -> None: """Test step plot with intervals explicitly on x axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_y(self) -> None: """Test step plot with intervals explicitly on y axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_x_and_y_raises_valueeerror(self) -> None: """Test that step plot with intervals both on x and y axes raises an error.""" @@ -928,9 +940,9 @@ def test_cmap_sequential_option(self) -> None: assert cmap_params["cmap"] == "magma" def test_cmap_sequential_explicit_option(self) -> None: - with xr.set_options(cmap_sequential=mpl.cm.magma): + with xr.set_options(cmap_sequential=mpl.colormaps["magma"]): cmap_params = _determine_cmap_params(self.data) - assert cmap_params["cmap"] == mpl.cm.magma + assert cmap_params["cmap"] == mpl.colormaps["magma"] def test_cmap_divergent_option(self) -> None: with xr.set_options(cmap_divergent="magma"): @@ -1170,7 +1182,7 @@ def test_discrete_colormap_list_of_levels(self) -> None: def test_discrete_colormap_int_levels(self) -> None: for extend, levels, vmin, vmax, cmap in [ ("neither", 7, None, None, None), - ("neither", 7, None, 20, mpl.cm.RdBu), + ("neither", 7, None, 20, mpl.colormaps["RdBu"]), ("both", 7, 4, 8, None), ("min", 10, 4, 15, None), ]: @@ -1720,8 +1732,8 @@ class TestContour(Common2dMixin, PlotTestCase): # matplotlib cmap.colors gives an rgbA ndarray # when seaborn is used, instead we get an rgb tuple @staticmethod - def _color_as_tuple(c): - return tuple(c[:3]) + def _color_as_tuple(c: Any) -> tuple[Any, Any, Any]: + return c[0], c[1], c[2] def test_colors(self) -> None: # with single color, we don't want rgb array @@ -1743,10 +1755,15 @@ def test_colors_np_levels(self) -> None: # https://github.com/pydata/xarray/issues/3284 levels = np.array([-0.5, 0.0, 0.5, 1.0]) artist = self.darray.plot.contour(levels=levels, colors=["k", "r", "w", "b"]) - assert self._color_as_tuple(artist.cmap.colors[1]) == (1.0, 0.0, 0.0) - assert self._color_as_tuple(artist.cmap.colors[2]) == (1.0, 1.0, 1.0) + cmap = artist.cmap + assert isinstance(cmap, mpl.colors.ListedColormap) + colors = cmap.colors + assert isinstance(colors, list) + + assert self._color_as_tuple(colors[1]) == (1.0, 0.0, 0.0) + assert self._color_as_tuple(colors[2]) == (1.0, 1.0, 1.0) # the last color is now under "over" - assert self._color_as_tuple(artist.cmap._rgba_over) == (0.0, 0.0, 1.0) + assert self._color_as_tuple(cmap._rgba_over) == (0.0, 0.0, 1.0) # type: ignore[attr-defined] def test_cmap_and_color_both(self) -> None: with pytest.raises(ValueError): @@ -1798,7 +1815,9 @@ def test_dont_infer_interval_breaks_for_cartopy(self) -> None: artist = self.plotmethod(x="x2d", y="y2d", ax=ax) assert isinstance(artist, mpl.collections.QuadMesh) # Let cartopy handle the axis limits and artist size - assert artist.get_array().size <= self.darray.size + arr = artist.get_array() + assert arr is not None + assert arr.size <= self.darray.size class TestPcolormeshLogscale(PlotTestCase): @@ -1949,6 +1968,7 @@ def test_normalize_rgb_imshow( ) -> None: da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) arr = da.plot.imshow(vmin=vmin, vmax=vmax, robust=robust).get_array() + assert arr is not None assert 0 <= arr.min() <= arr.max() <= 1 def test_normalize_rgb_one_arg_error(self) -> None: @@ -1965,7 +1985,10 @@ def test_imshow_rgb_values_in_valid_range(self) -> None: da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3))) _, ax = plt.subplots() out = da.plot.imshow(ax=ax).get_array() - assert out.dtype == np.uint8 + assert out is not None + dtype = out.dtype + assert dtype is not None + assert dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha @pytest.mark.filterwarnings("ignore:Several dimensions of this array") @@ -2000,6 +2023,7 @@ def test_2d_coord_names(self) -> None: self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) assert "x2d" == ax.get_xlabel() assert "y2d" == ax.get_ylabel() assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() @@ -2122,6 +2146,7 @@ def test_colorbar(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y") for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) clim = np.array(image.get_clim()) assert np.allclose(expected, clim) @@ -2132,8 +2157,8 @@ def test_colorbar_scatter(self) -> None: fg: xplt.FacetGrid = ds.plot.scatter(x="a", y="a", row="x", hue="a") cbar = fg.cbar assert cbar is not None - assert cbar.vmin == 0 - assert cbar.vmax == 3 + assert cbar.vmin == 0 # type: ignore[attr-defined] + assert cbar.vmax == 3 # type: ignore[attr-defined] @pytest.mark.slow def test_empty_cell(self) -> None: @@ -2199,6 +2224,7 @@ def test_can_set_vmin_vmax(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y", vmin=vmin, vmax=vmax) for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) clim = np.array(image.get_clim()) assert np.allclose(expected, clim) @@ -2215,6 +2241,7 @@ def test_can_set_norm(self) -> None: norm = mpl.colors.SymLogNorm(0.1) self.g.map_dataarray(xplt.imshow, "x", "y", norm=norm) for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) assert image.norm is norm @pytest.mark.slow @@ -2752,15 +2779,19 @@ def test_non_numeric_legend(self) -> None: ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] pc = ds2.plot.scatter(x="A", y="B", markersize="hue") + axes = pc.axes + assert axes is not None # should make a discrete legend - assert pc.axes.legend_ is not None + assert axes.legend_ is not None # type:ignore[attr-defined] def test_legend_labels(self) -> None: # regression test for #4126: incorrect legend labels ds2 = self.ds.copy() ds2["hue"] = ["a", "a", "b", "b"] pc = ds2.plot.scatter(x="A", y="B", markersize="hue") - actual = [t.get_text() for t in pc.axes.get_legend().texts] + axes = pc.axes + assert axes is not None + actual = [t.get_text() for t in axes.get_legend().texts] expected = ["hue", "a", "b"] assert actual == expected @@ -2781,7 +2812,9 @@ def test_legend_labels_facetgrid(self) -> None: def test_add_legend_by_default(self) -> None: sc = self.ds.plot.scatter(x="A", y="B", hue="hue") - assert len(sc.figure.axes) == 2 + fig = sc.figure + assert fig is not None + assert len(fig.axes) == 2 class TestDatetimePlot(PlotTestCase): @@ -2834,6 +2867,7 @@ def test_datetime_plot2d(self) -> None: p = da.plot.pcolormesh() ax = p.axes + assert ax is not None # Make sure only mpl converters are used, use type() so only # mpl.dates.AutoDateLocator passes and no other subclasses: From 3b62890c09487d904cab009129254d9f58008ae5 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 5 Sep 2023 21:43:17 +0200 Subject: [PATCH 02/20] fix scatter plot typing and remove funky pyplot import --- xarray/plot/dataarray_plot.py | 32 ++++++++++++++++------------ xarray/plot/facetgrid.py | 5 ++--- xarray/plot/utils.py | 39 ++++++++++++++--------------------- 3 files changed, 37 insertions(+), 39 deletions(-) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 4318b3db7b2..0be6d01a34f 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -29,7 +29,6 @@ _resolve_intervals_2dplot, _update_axes, get_axis, - import_matplotlib_pyplot, label_from_attrs, ) @@ -879,7 +878,7 @@ def newplotfunc( # All 1d plots in xarray share this function signature. # Method signature below should be consistent. - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt if subplot_kws is None: subplot_kws = dict() @@ -1082,12 +1081,12 @@ def newplotfunc( def _add_labels( add_labels: bool | Iterable[bool], - darrays: Iterable[DataArray], + darrays: Iterable[DataArray | None], suffixes: Iterable[str], rotate_labels: Iterable[bool], ax: Axes, ) -> None: - # Set x, y, z labels: + """Set x, y, z labels.""" add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels for axis, add_label, darray, suffix, rotate_label in zip( ("x", "y", "z"), add_labels, darrays, suffixes, rotate_labels @@ -1261,15 +1260,24 @@ def scatter( if sizeplt is not None: kwargs.update(s=sizeplt.to_numpy().ravel()) - axis_order = ["x", "y", "z"] + plts_or_none = (xplt, yplt, zplt) + _add_labels(add_labels, plts_or_none, ("", "", ""), (True, False, False), ax) - plts_dict: dict[str, DataArray | None] = dict(x=xplt, y=yplt, z=zplt) - plts_or_none = [plts_dict[v] for v in axis_order] - plts = [p.to_numpy().ravel() for p in plts_or_none if p is not None] - primitive = ax.scatter(*plts, **kwargs) - _add_labels(add_labels, plts, ("", "", ""), (True, False, False), ax) + xplt_np = None if xplt is None else xplt.to_numpy().ravel() + yplt_np = None if yplt is None else yplt.to_numpy().ravel() + zplt_np = None if zplt is None else zplt.to_numpy().ravel() + plts_np = tuple(p for p in (xplt_np, yplt_np, zplt_np) if p is not None) - return primitive + if len(plts_np) == 3: + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + return ax.scatter(xplt_np, yplt_np, zplt_np, **kwargs) + + if len(plts_np) == 2: + return ax.scatter(plts_np[0], plts_np[1], **kwargs) + + raise ValueError("At least two variables required for a scatter plot.") def _plot2d(plotfunc): @@ -1506,8 +1514,6 @@ def newplotfunc( # TypeError to be consistent with pandas raise TypeError("No numeric data to plot.") - plt = import_matplotlib_pyplot() - if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 93a328836d0..50f72d8906e 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -21,7 +21,6 @@ _Normalize, _parse_size, _process_cmap_cbar_kwargs, - import_matplotlib_pyplot, label_from_attrs, ) @@ -166,7 +165,7 @@ def __init__( """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique @@ -985,7 +984,7 @@ def map( self : FacetGrid object """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt for ax, namedict in zip(self.axs.flat, self.name_dicts.flat): if namedict is not None: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 2c58fe83cef..42c94a051a4 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -47,14 +47,6 @@ _LINEWIDTH_RANGE = (1.5, 1.5, 6.0) -def import_matplotlib_pyplot(): - """import pyplot""" - # TODO: This function doesn't do anything (after #6109), remove it? - import matplotlib.pyplot as plt - - return plt - - def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -505,28 +497,29 @@ def _maybe_gca(**subplot_kws: Any) -> Axes: return plt.axes(**subplot_kws) -def _get_units_from_attrs(da) -> str: +def _get_units_from_attrs(da: DataArray) -> str: """Extracts and formats the unit/units from a attributes.""" pint_array_type = DuckArrayModule("pint").type units = " [{}]" if isinstance(da.data, pint_array_type): - units = units.format(str(da.data.units)) - elif da.attrs.get("units"): - units = units.format(da.attrs["units"]) - elif da.attrs.get("unit"): - units = units.format(da.attrs["unit"]) - else: - units = "" - return units + return units.format(str(da.data.units)) + if "units" in da.attrs: + return units.format(da.attrs["units"]) + if "unit" in da.attrs: + return units.format(da.attrs["unit"]) + return "" -def label_from_attrs(da, extra: str = "") -> str: +def label_from_attrs(da: DataArray | None, extra: str = "") -> str: """Makes informative labels if variable metadata (attrs) follows CF conventions.""" + if da is None: + return "" + name: str = "{}" - if da.attrs.get("long_name"): + if "long_name" in da.attrs: name = name.format(da.attrs["long_name"]) - elif da.attrs.get("standard_name"): + elif "standard_name" in da.attrs: name = name.format(da.attrs["standard_name"]) elif da.name is not None: name = name.format(da.name) @@ -1166,7 +1159,7 @@ def _get_color_and_size(value): def _legend_add_subtitle(handles, labels, text): """Add a subtitle to legend handles.""" - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt if text and len(handles) > 1: # Create a blank handle that's not visible, the @@ -1184,7 +1177,7 @@ def _legend_add_subtitle(handles, labels, text): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None) @@ -1640,7 +1633,7 @@ def format(self) -> FuncFormatter: >>> aa.format(1) '3.0' """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt def _func(x: Any, pos: None | Any = None): return f"{self._lookup_arr([x])[0]}" From 7b44b1b48c1b10a6f2aa0b266dfea793768e3017 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 5 Sep 2023 22:03:17 +0200 Subject: [PATCH 03/20] fix some more typing errors --- xarray/plot/dataarray_plot.py | 2 +- xarray/plot/dataset_plot.py | 2 +- xarray/plot/facetgrid.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 0be6d01a34f..b89fa7bdf7b 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -52,7 +52,7 @@ ) from xarray.plot.facetgrid import FacetGrid -_styles: MutableMapping[str, Any] = { +_styles: dict[str, Any] = { # Add a white border to make it easier seeing overlapping markers: "scatter.edgecolors": "w", } diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index b0774c31b17..7a70f1d41b2 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -632,7 +632,7 @@ def streamplot( du = du.transpose(ydim, xdim) dv = dv.transpose(ydim, xdim) - args = [dx.values, dy.values, du.values, dv.values] + args = (dx.values, dy.values, du.values, dv.values) hue = kwargs.pop("hue") cmap_params = kwargs.pop("cmap_params") diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 50f72d8906e..0d1fc6c738a 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -730,6 +730,8 @@ def add_colorbar(self, **kwargs: Any) -> None: if hasattr(self._mappables[-1], "extend"): kwargs.pop("extend", None) if "label" not in kwargs: + data = self.data + assert isinstance(data, DataArray) kwargs.setdefault("label", label_from_attrs(self.data)) self.cbar = self.fig.colorbar( self._mappables[-1], ax=list(self.axs.flat), **kwargs From 9051f303d3ef6bbb22a76b28e4ecf6cd162bc4a5 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Thu, 7 Sep 2023 21:13:36 +0200 Subject: [PATCH 04/20] fix some import errors --- xarray/plot/facetgrid.py | 2 -- xarray/plot/utils.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 0d1fc6c738a..50f72d8906e 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -730,8 +730,6 @@ def add_colorbar(self, **kwargs: Any) -> None: if hasattr(self._mappables[-1], "extend"): kwargs.pop("extend", None) if "label" not in kwargs: - data = self.data - assert isinstance(data, DataArray) kwargs.setdefault("label", label_from_attrs(self.data)) self.cbar = self.fig.colorbar( self._mappables[-1], ax=list(self.axs.flat), **kwargs diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 42c94a051a4..91f9e07eb75 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -510,7 +510,7 @@ def _get_units_from_attrs(da: DataArray) -> str: return "" -def label_from_attrs(da: DataArray | None, extra: str = "") -> str: +def label_from_attrs(da: Dataset | DataArray | None, extra: str = "") -> str: """Makes informative labels if variable metadata (attrs) follows CF conventions.""" if da is None: From 55a63b6d9863de713af046cd9a09fb390f7345b9 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Thu, 7 Sep 2023 22:09:51 +0200 Subject: [PATCH 05/20] undo some typing errors --- xarray/core/options.py | 6 ++---- xarray/plot/facetgrid.py | 5 ++++- xarray/plot/utils.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/xarray/core/options.py b/xarray/core/options.py index 828d56cd300..a197cb4da10 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -6,10 +6,8 @@ from xarray.core.utils import FrozenDict if TYPE_CHECKING: - try: - from matplotlib.colors import Colormap - except ImportError: - Colormap = str + from matplotlib.colors import Colormap + Options = Literal[ "arithmetic_join", "cmap_divergent", diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 50f72d8906e..05670b269c3 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -680,7 +680,7 @@ def _finalize_grid(self, *axlabels: Hashable) -> None: def _adjust_fig_for_guide(self, guide) -> None: # Draw the plot to set the bounding boxes correctly - renderer = self.fig.canvas.get_renderer() + renderer = self.fig.canvas.get_renderer() # type: ignore[attr-defined] self.fig.draw(renderer) # Calculate and set the new width of the figure so the legend fits @@ -730,6 +730,9 @@ def add_colorbar(self, **kwargs: Any) -> None: if hasattr(self._mappables[-1], "extend"): kwargs.pop("extend", None) if "label" not in kwargs: + from xarray import DataArray + + assert isinstance(self.data, DataArray) kwargs.setdefault("label", label_from_attrs(self.data)) self.cbar = self.fig.colorbar( self._mappables[-1], ax=list(self.axs.flat), **kwargs diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 91f9e07eb75..42c94a051a4 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -510,7 +510,7 @@ def _get_units_from_attrs(da: DataArray) -> str: return "" -def label_from_attrs(da: Dataset | DataArray | None, extra: str = "") -> str: +def label_from_attrs(da: DataArray | None, extra: str = "") -> str: """Makes informative labels if variable metadata (attrs) follows CF conventions.""" if da is None: From 9e2c352662e5a3eb57ee7cb4ef49f625c0658de7 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Thu, 7 Sep 2023 22:28:51 +0200 Subject: [PATCH 06/20] fix xylim typing --- xarray/plot/dataarray_plot.py | 20 ++++++++++---------- xarray/plot/dataset_plot.py | 5 +++-- xarray/plot/utils.py | 4 ++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index b89fa7bdf7b..58b63d87f4b 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -409,8 +409,8 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -458,7 +458,7 @@ def line( Specifies scaling for the *x*- and *y*-axis, respectively. xticks, yticks : array-like, optional Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional + xlim, ylim : tuple[float, float], optional Specify *x*- and *y*-axis limits. add_legend : bool, default: True Add legend with *y* axis coordinates (2D inputs only). @@ -653,8 +653,8 @@ def hist( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, **kwargs: Any, ) -> tuple[np.ndarray, np.ndarray, BarContainer]: """ @@ -690,7 +690,7 @@ def hist( Specifies scaling for the *x*- and *y*-axis, respectively. xticks, yticks : array-like, optional Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional + xlim, ylim : tuple[float, float], optional Specify *x*- and *y*-axis limits. **kwargs : optional Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. @@ -1386,9 +1386,9 @@ def _plot2d(plotfunc): Specify tick locations for x-axes. yticks : ArrayLike or None, optional Specify tick locations for y-axes. - xlim : ArrayLike or None, optional + xlim : tuple[float, float] or None, optional Specify x-axes limits. - ylim : ArrayLike or None, optional + ylim : tuple[float, float] or None, optional Specify y-axes limits. norm : matplotlib.colors.Normalize, optional If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding @@ -1441,8 +1441,8 @@ def newplotfunc( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> Any: diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 7a70f1d41b2..b8d23c13901 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -632,7 +632,6 @@ def streamplot( du = du.transpose(ydim, xdim) dv = dv.transpose(ydim, xdim) - args = (dx.values, dy.values, du.values, dv.values) hue = kwargs.pop("hue") cmap_params = kwargs.pop("cmap_params") @@ -646,7 +645,9 @@ def streamplot( ) kwargs.pop("hue_style") - hdl = ax.streamplot(*args, **kwargs, **cmap_params) + hdl = ax.streamplot( + dx.values, dy.values, du.values, dv.values, **kwargs, **cmap_params + ) # Return .lines so colorbar creation works properly return hdl.lines diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 42c94a051a4..243154e79f3 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -767,8 +767,8 @@ def _update_axes( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, ) -> None: """ Update axes with provided parameters From 42d33569e9fad3a1edc9b2eac22d7c3ee9448d7e Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 11 Sep 2023 21:02:34 +0200 Subject: [PATCH 07/20] add forgotten import --- xarray/plot/dataarray_plot.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 58b63d87f4b..fdb6b4f4f59 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -1626,6 +1626,8 @@ def newplotfunc( ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) if plotfunc.__name__ == "surface": + import mpl_toolkits + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) ax.set_zlabel(label_from_attrs(darray)) From b5d224980ad5123c616caa72c4b1d59b25f827cf Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 11 Sep 2023 21:15:21 +0200 Subject: [PATCH 08/20] ignore plotting overloads because None is Hashable --- xarray/plot/accessor.py | 16 ++++++++-------- xarray/plot/dataarray_plot.py | 10 +++++----- xarray/plot/dataset_plot.py | 6 +++--- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index ff707602545..9c0cf6fb699 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -179,7 +179,7 @@ def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.step(self._da, *args, **kwargs) @overload - def scatter( + def scatter( # type: ignore[misc] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -306,7 +306,7 @@ def scatter(self, *args, **kwargs): return dataarray_plot.scatter(self._da, *args, **kwargs) @overload - def imshow( + def imshow( # type: ignore[misc] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -430,7 +430,7 @@ def imshow(self, *args, **kwargs) -> AxesImage: return dataarray_plot.imshow(self._da, *args, **kwargs) @overload - def contour( + def contour( # type: ignore[misc] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -554,7 +554,7 @@ def contour(self, *args, **kwargs) -> QuadContourSet: return dataarray_plot.contour(self._da, *args, **kwargs) @overload - def contourf( + def contourf( # type: ignore[misc] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -678,7 +678,7 @@ def contourf(self, *args, **kwargs) -> QuadContourSet: return dataarray_plot.contourf(self._da, *args, **kwargs) @overload - def pcolormesh( + def pcolormesh( # type: ignore[misc] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -945,7 +945,7 @@ def __call__(self, *args, **kwargs) -> NoReturn: ) @overload - def scatter( + def scatter( # type: ignore[misc] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1072,7 +1072,7 @@ def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: return dataset_plot.scatter(self._ds, *args, **kwargs) @overload - def quiver( + def quiver( # type: ignore[misc] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1187,7 +1187,7 @@ def quiver(self, *args, **kwargs) -> Quiver | FacetGrid: return dataset_plot.quiver(self._ds, *args, **kwargs) @overload - def streamplot( + def streamplot( # type: ignore[misc] # None is hashable :( self, *args: Any, x: Hashable | None = None, diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index fdb6b4f4f59..d1db5e343f3 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -1110,7 +1110,7 @@ def _add_labels( @overload -def scatter( +def scatter( # type: ignore[misc] # None is hashable :( darray: DataArray, *args: Any, x: Hashable | None = None, @@ -1669,7 +1669,7 @@ def newplotfunc( @overload -def imshow( +def imshow( # type: ignore[misc] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -1888,7 +1888,7 @@ def _center_pixels(x): @overload -def contour( +def contour( # type: ignore[misc] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2024,7 +2024,7 @@ def contour( @overload -def contourf( +def contourf( # type: ignore[misc] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2160,7 +2160,7 @@ def contourf( @overload -def pcolormesh( +def pcolormesh( # type: ignore[misc] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index b8d23c13901..920bdd5368a 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -321,7 +321,7 @@ def newplotfunc( @overload -def quiver( +def quiver( # type: ignore[misc] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -475,7 +475,7 @@ def quiver( @overload -def streamplot( +def streamplot( # type: ignore[misc] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -749,7 +749,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr @overload -def scatter( +def scatter( # type: ignore[misc] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, From d1cc21699173e369d80a39b9cfb486ca8b9fdb2e Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 11 Sep 2023 21:26:21 +0200 Subject: [PATCH 09/20] add whats-new entry --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 157795f08d1..6f672609328 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,9 @@ Breaking changes extracts and add the indexes from another :py:class:`Coordinates` object passed via ``coords`` (:pull:`8107`). By `Benoît Bovy `_. +- Static typing of `xlim` and `ylim` arguments in plotting functions now must + be `tuple[float, float]` to align with matplotlib requirements. (:issue:`7802`, :pull:`8030`). + By `Michael Niklas `_. Deprecations ~~~~~~~~~~~~ From 428a1ed9d2479d13527329f53e45acf74699c0e4 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 11 Sep 2023 21:41:05 +0200 Subject: [PATCH 10/20] fix return type of hist --- doc/whats-new.rst | 4 ++-- xarray/plot/accessor.py | 5 ++++- xarray/plot/dataarray_plot.py | 12 ++++++++---- xarray/tests/test_plot.py | 7 +++++-- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6f672609328..78939a9c9a4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,8 +36,8 @@ Breaking changes extracts and add the indexes from another :py:class:`Coordinates` object passed via ``coords`` (:pull:`8107`). By `Benoît Bovy `_. -- Static typing of `xlim` and `ylim` arguments in plotting functions now must - be `tuple[float, float]` to align with matplotlib requirements. (:issue:`7802`, :pull:`8030`). +- Static typing of ``xlim`` and ``ylim`` arguments in plotting functions now must + be ``tuple[float, float]`` to align with matplotlib requirements. (:issue:`7802`, :pull:`8030`). By `Michael Niklas `_. Deprecations diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index 9c0cf6fb699..fbdead9cd3f 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -16,6 +16,7 @@ from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet from matplotlib.image import AxesImage + from matplotlib.patches import Polygon from matplotlib.quiver import Quiver from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection from numpy.typing import ArrayLike @@ -47,7 +48,9 @@ def __call__(self, **kwargs) -> Any: return dataarray_plot.plot(self._da, **kwargs) @functools.wraps(dataarray_plot.hist) - def hist(self, *args, **kwargs) -> tuple[np.ndarray, np.ndarray, BarContainer]: + def hist( + self, *args, **kwargs + ) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: return dataarray_plot.hist(self._da, *args, **kwargs) @overload diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index d1db5e343f3..4ba0dda856b 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -3,7 +3,7 @@ import functools import warnings from collections.abc import Hashable, Iterable, MutableMapping -from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload import numpy as np import pandas as pd @@ -39,6 +39,7 @@ from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet from matplotlib.image import AxesImage + from matplotlib.patches import Polygon from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection from numpy.typing import ArrayLike @@ -656,7 +657,7 @@ def hist( xlim: tuple[float, float] | None = None, ylim: tuple[float, float] | None = None, **kwargs: Any, -) -> tuple[np.ndarray, np.ndarray, BarContainer]: +) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: """ Histogram of DataArray. @@ -707,14 +708,17 @@ def hist( no_nan = np.ravel(darray.to_numpy()) no_nan = no_nan[pd.notnull(no_nan)] - primitive = ax.hist(no_nan, **kwargs) + n, bins, patches = cast( + tuple[np.ndarray, np.ndarray, Union["BarContainer", "Polygon"]], + ax.hist(no_nan, **kwargs), + ) ax.set_title(darray._title_for_slice()) ax.set_xlabel(label_from_attrs(darray)) _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - return primitive + return n, bins, patches def _plot1d(plotfunc): diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index fc80c60e00f..d1ee7aadc53 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -895,8 +895,11 @@ def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot.hist) def test_primitive_returned(self) -> None: - h = self.darray.plot.hist() - assert isinstance(h[-1][0], mpl.patches.Rectangle) + n, bins, patches = self.darray.plot.hist() + assert isinstance(n, np.ndarray) + assert isinstance(bins, np.ndarray) + assert isinstance(patches, mpl.container.BarContainer) + assert isinstance(patches[0], mpl.patches.Rectangle) @pytest.mark.slow def test_plot_nans(self) -> None: From 2ab1402d769e6a89c6553bb2d25e8682e30d84aa Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 11 Sep 2023 21:44:45 +0200 Subject: [PATCH 11/20] fix another xylim type --- xarray/plot/dataarray_plot.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 4ba0dda856b..cac7ac87ebe 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -782,9 +782,9 @@ def _plot1d(plotfunc): Specify tick locations for x-axes. yticks : ArrayLike or None, optional Specify tick locations for y-axes. - xlim : ArrayLike or None, optional + xlim : tuple[float, float] or None, optional Specify x-axes limits. - ylim : ArrayLike or None, optional + ylim : tuple[float, float] or None, optional Specify y-axes limits. cmap : matplotlib colormap name or colormap, optional The mapping from data values to color space. Either a @@ -869,8 +869,8 @@ def newplotfunc( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, From db0db6408c30c2e62868265c28135848cc35a4b1 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 11 Sep 2023 21:48:01 +0200 Subject: [PATCH 12/20] fix some more xylim types --- xarray/plot/dataarray_plot.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index cac7ac87ebe..30dc93e4edf 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -325,8 +325,8 @@ def line( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -336,7 +336,7 @@ def line( # type: ignore[misc] # None is hashable :( @overload def line( - darray, + darray: DataArray, *args: Any, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, @@ -353,8 +353,8 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -364,7 +364,7 @@ def line( @overload def line( - darray, + darray: DataArray, *args: Any, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid @@ -381,8 +381,8 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, From 8897ba1780d38e538a28c579c61c536ef788d7e6 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 12 Sep 2023 20:12:47 +0200 Subject: [PATCH 13/20] change to T_DataArray --- xarray/plot/dataarray_plot.py | 60 +++++++++++++++++------------------ xarray/plot/facetgrid.py | 12 +++---- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index f09a59fa26b..67dee95d0ca 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -336,7 +336,7 @@ def line( # type: ignore[misc] # None is hashable :( @overload def line( - darray: DataArray, + darray: T_DataArray, *args: Any, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, @@ -358,13 +358,13 @@ def line( add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def line( - darray: DataArray, + darray: T_DataArray, *args: Any, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid @@ -386,14 +386,14 @@ def line( add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... # This function signature should not change so that it can use # matplotlib format strings def line( - darray: DataArray, + darray: T_DataArray, *args: Any, row: Hashable | None = None, col: Hashable | None = None, @@ -415,7 +415,7 @@ def line( add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> list[Line3D] | FacetGrid[DataArray]: +) -> list[Line3D] | FacetGrid[T_DataArray]: """ Line plot of DataArray values. @@ -1157,7 +1157,7 @@ def scatter( # type: ignore[misc] # None is hashable :( @overload def scatter( - darray: DataArray, + darray: T_DataArray, *args: Any, x: Hashable | None = None, y: Hashable | None = None, @@ -1193,13 +1193,13 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def scatter( - darray: DataArray, + darray: T_DataArray, *args: Any, x: Hashable | None = None, y: Hashable | None = None, @@ -1235,7 +1235,7 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -1715,7 +1715,7 @@ def imshow( # type: ignore[misc] # None is hashable :( @overload def imshow( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1750,13 +1750,13 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def imshow( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1791,7 +1791,7 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -1934,7 +1934,7 @@ def contour( # type: ignore[misc] # None is hashable :( @overload def contour( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1969,13 +1969,13 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def contour( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2010,7 +2010,7 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2070,7 +2070,7 @@ def contourf( # type: ignore[misc] # None is hashable :( @overload def contourf( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2105,13 +2105,13 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def contourf( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2146,7 +2146,7 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2206,7 +2206,7 @@ def pcolormesh( # type: ignore[misc] # None is hashable :( @overload def pcolormesh( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2241,13 +2241,13 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def pcolormesh( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2282,7 +2282,7 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2393,7 +2393,7 @@ def surface( @overload def surface( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2428,13 +2428,13 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def surface( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2469,7 +2469,7 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 419ab202c30..a8e7c9d674d 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,7 +9,7 @@ import numpy as np from xarray.core.formatting import format_item -from xarray.core.types import HueStyleOptions, T_Xarray +from xarray.core.types import HueStyleOptions, T_DataArrayOrSet from xarray.plot.utils import ( _LINEWIDTH_RANGE, _MARKERSIZE_RANGE, @@ -59,7 +59,7 @@ def _nicetitle(coord, value, maxchar, template): T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid") -class FacetGrid(Generic[T_Xarray]): +class FacetGrid(Generic[T_DataArrayOrSet]): """ Initialize the Matplotlib figure and FacetGrid object. @@ -100,7 +100,7 @@ class FacetGrid(Generic[T_Xarray]): sometimes the rightmost grid positions in the bottom row. """ - data: T_Xarray + data: T_DataArrayOrSet name_dicts: np.ndarray fig: Figure axs: np.ndarray @@ -125,7 +125,7 @@ class FacetGrid(Generic[T_Xarray]): def __init__( self, - data: T_Xarray, + data: T_DataArrayOrSet, col: Hashable | None = None, row: Hashable | None = None, col_wrap: int | None = None, @@ -1006,7 +1006,7 @@ def map( def _easy_facetgrid( - data: T_Xarray, + data: T_DataArrayOrSet, plotfunc: Callable, kind: Literal["line", "dataarray", "dataset", "plot1d"], x: Hashable | None = None, @@ -1022,7 +1022,7 @@ def _easy_facetgrid( ax: Axes | None = None, figsize: Iterable[float] | None = None, **kwargs: Any, -) -> FacetGrid[T_Xarray]: +) -> FacetGrid[T_DataArrayOrSet]: """ Convenience method to call xarray.plot.FacetGrid from 2d plotting methods From 64c2401fe0c1dcbabf40ea3c041040648e4259a7 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 12 Sep 2023 20:18:40 +0200 Subject: [PATCH 14/20] change accessor xylim to tuple --- xarray/plot/accessor.py | 96 ++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index fbdead9cd3f..736792400d3 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -72,8 +72,8 @@ def line( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -99,8 +99,8 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -126,8 +126,8 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -210,8 +210,8 @@ def scatter( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -251,8 +251,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -292,8 +292,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -341,8 +341,8 @@ def imshow( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> AxesImage: @@ -381,8 +381,8 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -421,8 +421,8 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -465,8 +465,8 @@ def contour( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: @@ -505,8 +505,8 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -545,8 +545,8 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -589,8 +589,8 @@ def contourf( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: @@ -629,8 +629,8 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -669,8 +669,8 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: @@ -713,8 +713,8 @@ def pcolormesh( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadMesh: @@ -753,8 +753,8 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: @@ -793,8 +793,8 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: @@ -837,8 +837,8 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> Poly3DCollection: @@ -877,8 +877,8 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: @@ -917,8 +917,8 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: @@ -976,8 +976,8 @@ def scatter( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1017,8 +1017,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1058,8 +1058,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, From 79a92f5e2875bf973dd01f28ba4eae6b7b6f6c05 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 12 Sep 2023 21:36:04 +0200 Subject: [PATCH 15/20] add missing return types --- xarray/plot/accessor.py | 62 ++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index 736792400d3..d70aed9f01e 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -2,7 +2,13 @@ import functools from collections.abc import Hashable, Iterable -from typing import TYPE_CHECKING, Any, Literal, NoReturn, overload +from typing import ( + TYPE_CHECKING, + Any, + Literal, + NoReturn, + overload, +) import numpy as np @@ -134,7 +140,7 @@ def line( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.line) + @functools.wraps(dataarray_plot.line, assigned=("__doc__")) def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.line(self._da, *args, **kwargs) @@ -177,7 +183,7 @@ def step( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.step) + @functools.wraps(dataarray_plot.step, assigned=("__doc__")) def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.step(self._da, *args, **kwargs) @@ -304,8 +310,8 @@ def scatter( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.scatter) - def scatter(self, *args, **kwargs): + @functools.wraps(dataarray_plot.scatter, assigned=("__doc__")) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: return dataarray_plot.scatter(self._da, *args, **kwargs) @overload @@ -428,8 +434,8 @@ def imshow( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.imshow) - def imshow(self, *args, **kwargs) -> AxesImage: + @functools.wraps(dataarray_plot.imshow, assigned=("__doc__")) + def imshow(self, *args, **kwargs) -> AxesImage | FacetGrid[DataArray]: return dataarray_plot.imshow(self._da, *args, **kwargs) @overload @@ -552,8 +558,8 @@ def contour( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.contour) - def contour(self, *args, **kwargs) -> QuadContourSet: + @functools.wraps(dataarray_plot.contour, assigned=("__doc__")) + def contour(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contour(self._da, *args, **kwargs) @overload @@ -676,8 +682,8 @@ def contourf( ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.contourf) - def contourf(self, *args, **kwargs) -> QuadContourSet: + @functools.wraps(dataarray_plot.contourf, assigned=("__doc__")) + def contourf(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contourf(self._da, *args, **kwargs) @overload @@ -757,7 +763,7 @@ def pcolormesh( ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[DataArray]: ... @overload @@ -797,11 +803,11 @@ def pcolormesh( ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.pcolormesh) - def pcolormesh(self, *args, **kwargs) -> QuadMesh: + @functools.wraps(dataarray_plot.pcolormesh, assigned=("__doc__")) + def pcolormesh(self, *args, **kwargs) -> QuadMesh | FacetGrid[DataArray]: return dataarray_plot.pcolormesh(self._da, *args, **kwargs) @overload @@ -924,7 +930,7 @@ def surface( ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.surface) + @functools.wraps(dataarray_plot.surface, assigned=("__doc__")) def surface(self, *args, **kwargs) -> Poly3DCollection: return dataarray_plot.surface(self._da, *args, **kwargs) @@ -1026,7 +1032,7 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> FacetGrid[DataArray]: + ) -> FacetGrid[Dataset]: ... @overload @@ -1067,11 +1073,11 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> FacetGrid[DataArray]: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.scatter) - def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: + @functools.wraps(dataset_plot.scatter, assigned=("__doc__")) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[Dataset]: return dataset_plot.scatter(self._ds, *args, **kwargs) @overload @@ -1145,7 +1151,7 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... @overload @@ -1182,11 +1188,11 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.quiver) - def quiver(self, *args, **kwargs) -> Quiver | FacetGrid: + @functools.wraps(dataset_plot.quiver, assigned=("__doc__")) + def quiver(self, *args, **kwargs) -> Quiver | FacetGrid[Dataset]: return dataset_plot.quiver(self._ds, *args, **kwargs) @overload @@ -1260,7 +1266,7 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... @overload @@ -1297,9 +1303,9 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.streamplot) - def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid: + @functools.wraps(dataset_plot.streamplot, assigned=("__doc__")) + def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid[Dataset]: return dataset_plot.streamplot(self._ds, *args, **kwargs) From b9dcc6adccd860ffbee8a507951a7d8138f90ecc Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 12 Sep 2023 21:56:27 +0200 Subject: [PATCH 16/20] fix a typing error only on new mpl --- xarray/plot/accessor.py | 8 +------- xarray/plot/facetgrid.py | 5 ++++- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index d70aed9f01e..7791da7c8b4 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -2,13 +2,7 @@ import functools from collections.abc import Hashable, Iterable -from typing import ( - TYPE_CHECKING, - Any, - Literal, - NoReturn, - overload, -) +from typing import TYPE_CHECKING, Any, Literal, NoReturn, overload import numpy as np diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index a8e7c9d674d..faf809a8a74 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -680,7 +680,10 @@ def _finalize_grid(self, *axlabels: Hashable) -> None: def _adjust_fig_for_guide(self, guide) -> None: # Draw the plot to set the bounding boxes correctly - renderer = self.fig.canvas.get_renderer() # type: ignore[attr-defined] + if hasattr(self.fig.canvas, "get_renderer"): + renderer = self.fig.canvas.get_renderer() + else: + raise RuntimeError("MPL backend has no renderer") self.fig.draw(renderer) # Calculate and set the new width of the figure so the legend fits From 1eef36138a07901a79b9d2d0c850c6a49a2873eb Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 12 Sep 2023 21:59:17 +0200 Subject: [PATCH 17/20] add unused-ignore to error codes for old mpl --- xarray/plot/accessor.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index 7791da7c8b4..d842ef45ab3 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -54,7 +54,7 @@ def hist( return dataarray_plot.hist(self._da, *args, **kwargs) @overload - def line( # type: ignore[misc] # None is hashable :( + def line( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, row: None = None, # no wrap -> primitive @@ -139,7 +139,7 @@ def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.line(self._da, *args, **kwargs) @overload - def step( # type: ignore[misc] # None is hashable :( + def step( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, where: Literal["pre", "post", "mid"] = "pre", @@ -182,7 +182,7 @@ def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.step(self._da, *args, **kwargs) @overload - def scatter( # type: ignore[misc] # None is hashable :( + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -309,7 +309,7 @@ def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: return dataarray_plot.scatter(self._da, *args, **kwargs) @overload - def imshow( # type: ignore[misc] # None is hashable :( + def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -433,7 +433,7 @@ def imshow(self, *args, **kwargs) -> AxesImage | FacetGrid[DataArray]: return dataarray_plot.imshow(self._da, *args, **kwargs) @overload - def contour( # type: ignore[misc] # None is hashable :( + def contour( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -557,7 +557,7 @@ def contour(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contour(self._da, *args, **kwargs) @overload - def contourf( # type: ignore[misc] # None is hashable :( + def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -681,7 +681,7 @@ def contourf(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contourf(self._da, *args, **kwargs) @overload - def pcolormesh( # type: ignore[misc] # None is hashable :( + def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -948,7 +948,7 @@ def __call__(self, *args, **kwargs) -> NoReturn: ) @overload - def scatter( # type: ignore[misc] # None is hashable :( + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1075,7 +1075,7 @@ def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[Dataset]: return dataset_plot.scatter(self._ds, *args, **kwargs) @overload - def quiver( # type: ignore[misc] # None is hashable :( + def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1190,7 +1190,7 @@ def quiver(self, *args, **kwargs) -> Quiver | FacetGrid[Dataset]: return dataset_plot.quiver(self._ds, *args, **kwargs) @overload - def streamplot( # type: ignore[misc] # None is hashable :( + def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, From f7246feecc0d90c72e537a4f3fe16b90e5d25d57 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 12 Sep 2023 22:00:19 +0200 Subject: [PATCH 18/20] add more unused-ignore to error codes for old mpl --- xarray/plot/dataarray_plot.py | 14 +++++++------- xarray/plot/dataset_plot.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 67dee95d0ca..b549c7fa420 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -307,7 +307,7 @@ def plot( @overload -def line( # type: ignore[misc] # None is hashable :( +def line( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, row: None = None, # no wrap -> primitive @@ -538,7 +538,7 @@ def line( @overload -def step( # type: ignore[misc] # None is hashable :( +def step( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, where: Literal["pre", "post", "mid"] = "pre", @@ -1114,7 +1114,7 @@ def _add_labels( @overload -def scatter( # type: ignore[misc] # None is hashable :( +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, x: Hashable | None = None, @@ -1673,7 +1673,7 @@ def newplotfunc( @overload -def imshow( # type: ignore[misc] # None is hashable :( +def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -1892,7 +1892,7 @@ def _center_pixels(x): @overload -def contour( # type: ignore[misc] # None is hashable :( +def contour( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2028,7 +2028,7 @@ def contour( @overload -def contourf( # type: ignore[misc] # None is hashable :( +def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2164,7 +2164,7 @@ def contourf( @overload -def pcolormesh( # type: ignore[misc] # None is hashable :( +def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 86b74105628..1ebb47d7949 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -321,7 +321,7 @@ def newplotfunc( @overload -def quiver( # type: ignore[misc] # None is hashable :( +def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -475,7 +475,7 @@ def quiver( @overload -def streamplot( # type: ignore[misc] # None is hashable :( +def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -749,7 +749,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr @overload -def scatter( # type: ignore[misc] # None is hashable :( +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, From 33c6f207bafe734d43a4d317015b358de9ab4956 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 12 Sep 2023 22:04:33 +0200 Subject: [PATCH 19/20] replace type: ignore[attr-defined] with assert hasattr --- xarray/tests/test_plot.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index d1ee7aadc53..b0e6ff90bc7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1766,7 +1766,8 @@ def test_colors_np_levels(self) -> None: assert self._color_as_tuple(colors[1]) == (1.0, 0.0, 0.0) assert self._color_as_tuple(colors[2]) == (1.0, 1.0, 1.0) # the last color is now under "over" - assert self._color_as_tuple(cmap._rgba_over) == (0.0, 0.0, 1.0) # type: ignore[attr-defined] + assert hasattr(cmap, "_rgba_over") + assert self._color_as_tuple(cmap._rgba_over) == (0.0, 0.0, 1.0) def test_cmap_and_color_both(self) -> None: with pytest.raises(ValueError): @@ -2160,8 +2161,10 @@ def test_colorbar_scatter(self) -> None: fg: xplt.FacetGrid = ds.plot.scatter(x="a", y="a", row="x", hue="a") cbar = fg.cbar assert cbar is not None - assert cbar.vmin == 0 # type: ignore[attr-defined] - assert cbar.vmax == 3 # type: ignore[attr-defined] + assert hasattr(cbar, "vmin") + assert cbar.vmin == 0 + assert hasattr(cbar, "vmax") + assert cbar.vmax == 3 @pytest.mark.slow def test_empty_cell(self) -> None: @@ -2785,7 +2788,8 @@ def test_non_numeric_legend(self) -> None: axes = pc.axes assert axes is not None # should make a discrete legend - assert axes.legend_ is not None # type:ignore[attr-defined] + assert hasattr(axes, "legend_") + assert axes.legend_ is not None def test_legend_labels(self) -> None: # regression test for #4126: incorrect legend labels From b1d721c0e5362b20c333c43546c0f3d9c9a60d35 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 16 Sep 2023 21:22:27 +0200 Subject: [PATCH 20/20] apply code review suggestions --- xarray/plot/accessor.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index d842ef45ab3..203bae2691f 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -134,7 +134,7 @@ def line( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.line, assigned=("__doc__")) + @functools.wraps(dataarray_plot.line, assigned=("__doc__",)) def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.line(self._da, *args, **kwargs) @@ -177,7 +177,7 @@ def step( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.step, assigned=("__doc__")) + @functools.wraps(dataarray_plot.step, assigned=("__doc__",)) def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.step(self._da, *args, **kwargs) @@ -304,7 +304,7 @@ def scatter( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.scatter, assigned=("__doc__")) + @functools.wraps(dataarray_plot.scatter, assigned=("__doc__",)) def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: return dataarray_plot.scatter(self._da, *args, **kwargs) @@ -428,7 +428,7 @@ def imshow( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.imshow, assigned=("__doc__")) + @functools.wraps(dataarray_plot.imshow, assigned=("__doc__",)) def imshow(self, *args, **kwargs) -> AxesImage | FacetGrid[DataArray]: return dataarray_plot.imshow(self._da, *args, **kwargs) @@ -552,7 +552,7 @@ def contour( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.contour, assigned=("__doc__")) + @functools.wraps(dataarray_plot.contour, assigned=("__doc__",)) def contour(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contour(self._da, *args, **kwargs) @@ -676,7 +676,7 @@ def contourf( ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.contourf, assigned=("__doc__")) + @functools.wraps(dataarray_plot.contourf, assigned=("__doc__",)) def contourf(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contourf(self._da, *args, **kwargs) @@ -800,7 +800,7 @@ def pcolormesh( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.pcolormesh, assigned=("__doc__")) + @functools.wraps(dataarray_plot.pcolormesh, assigned=("__doc__",)) def pcolormesh(self, *args, **kwargs) -> QuadMesh | FacetGrid[DataArray]: return dataarray_plot.pcolormesh(self._da, *args, **kwargs) @@ -924,7 +924,7 @@ def surface( ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.surface, assigned=("__doc__")) + @functools.wraps(dataarray_plot.surface, assigned=("__doc__",)) def surface(self, *args, **kwargs) -> Poly3DCollection: return dataarray_plot.surface(self._da, *args, **kwargs) @@ -1070,7 +1070,7 @@ def scatter( ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.scatter, assigned=("__doc__")) + @functools.wraps(dataset_plot.scatter, assigned=("__doc__",)) def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[Dataset]: return dataset_plot.scatter(self._ds, *args, **kwargs) @@ -1185,7 +1185,7 @@ def quiver( ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.quiver, assigned=("__doc__")) + @functools.wraps(dataset_plot.quiver, assigned=("__doc__",)) def quiver(self, *args, **kwargs) -> Quiver | FacetGrid[Dataset]: return dataset_plot.quiver(self._ds, *args, **kwargs) @@ -1300,6 +1300,6 @@ def streamplot( ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.streamplot, assigned=("__doc__")) + @functools.wraps(dataset_plot.streamplot, assigned=("__doc__",)) def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid[Dataset]: return dataset_plot.streamplot(self._ds, *args, **kwargs)