diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 398c332433f..db10ec653c5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,6 +64,8 @@ Bug fixes By `Richard Kleijn `_ . - Remove dictionary unpacking when using ``.loc`` to avoid collision with ``.sel`` parameters (:pull:`4695`). By `Anderson Banihirwe `_ +- Fix the legend created by :py:meth:`Dataset.plot.scatter` (:issue:`4641`, :pull:`4723`). + By `Justus Magin `_. - Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`). By `Alessandro Amici `_ - Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations, diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 7ba0f93f33a..6d942e1b0fa 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -291,7 +291,7 @@ def newplotfunc( allargs = locals().copy() allargs["plotfunc"] = globals()[plotfunc.__name__] allargs["data"] = ds - # TODO dcherian: why do I need to remove kwargs? + # remove kwargs to avoid passing the information twice for arg in ["meta_data", "kwargs", "ds"]: del allargs[arg] @@ -422,7 +422,10 @@ def scatter(ds, x, y, ax, **kwargs): if hue_style == "discrete": primitive = [] - for label in np.unique(data["hue"].values): + # use pd.unique instead of np.unique because that keeps the order of the labels, + # which is important to keep them in sync with the ones used in + # FacetGrid.add_legend + for label in pd.unique(data["hue"].values.ravel()): mask = data["hue"] == label if data["sizes"] is not None: kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten()) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 471bbb7051e..47b15446f1d 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2290,6 +2290,17 @@ def test_legend_labels(self): lines = ds2.plot.scatter(x="A", y="B", hue="hue") assert [t.get_text() for t in lines[0].axes.get_legend().texts] == ["a", "b"] + def test_legend_labels_facetgrid(self): + ds2 = self.ds.copy() + ds2["hue"] = ["d", "a", "c", "b"] + g = ds2.plot.scatter(x="A", y="B", hue="hue", col="col") + legend_labels = tuple(t.get_text() for t in g.figlegend.texts) + attached_labels = [ + tuple(m.get_label() for m in mappables_per_ax) + for mappables_per_ax in g._mappables + ] + assert list(set(attached_labels)) == [legend_labels] + def test_add_legend_by_default(self): sc = self.ds.plot.scatter(x="A", y="B", hue="hue") assert len(sc.figure.axes) == 2