Skip to content

Commit

Permalink
scatter plot by order of the first appearance of hue (#4723)
Browse files Browse the repository at this point in the history
* plot by order of first appearance

* use ravel to avoid copying the data

* update whats-new.rst

* add a test to make sure the legend labels and the mappable labels match

* test with upstream-dev [test-upstream]

* add a comment about the reason for using pd.unique [skip-ci]

* empty commit [skip-ci]
  • Loading branch information
keewis authored Jan 13, 2021
1 parent 1ce8938 commit 747fe26
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ Bug fixes
By `Richard Kleijn <https://github.com/rhkleijn>`_.
- Remove dictionary unpacking when using ``.loc`` to avoid collision with ``.sel`` parameters (:pull:`4695`).
By `Anderson Banihirwe <https://github.com/andersy005>`_.
- Fix the legend created by :py:meth:`Dataset.plot.scatter` (:issue:`4641`, :pull:`4723`).
By `Justus Magin <https://github.com/keewis>`_.
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733`, :pull:`4737`).
By `Alessandro Amici <https://github.com/alexamici>`_.
- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations,
Expand Down
7 changes: 5 additions & 2 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def newplotfunc(
allargs = locals().copy()
allargs["plotfunc"] = globals()[plotfunc.__name__]
allargs["data"] = ds
# TODO dcherian: why do I need to remove kwargs?
# remove kwargs to avoid passing the information twice
for arg in ["meta_data", "kwargs", "ds"]:
del allargs[arg]

Expand Down Expand Up @@ -422,7 +422,10 @@ def scatter(ds, x, y, ax, **kwargs):

if hue_style == "discrete":
primitive = []
for label in np.unique(data["hue"].values):
# use pd.unique instead of np.unique because that keeps the order of the labels,
# which is important to keep them in sync with the ones used in
# FacetGrid.add_legend
for label in pd.unique(data["hue"].values.ravel()):
mask = data["hue"] == label
if data["sizes"] is not None:
kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten())
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2290,6 +2290,17 @@ def test_legend_labels(self):
lines = ds2.plot.scatter(x="A", y="B", hue="hue")
assert [t.get_text() for t in lines[0].axes.get_legend().texts] == ["a", "b"]

def test_legend_labels_facetgrid(self):
    """Check that the figure legend and the mappables of every facet agree.

    The labels shown in the figure legend must be the same, and in the same
    order, as the labels attached to the mappables drawn on each axes.
    """
    dataset = self.ds.copy()
    # hue labels are given in non-sorted order on purpose
    dataset["hue"] = ["d", "a", "c", "b"]
    grid = dataset.plot.scatter(x="A", y="B", hue="hue", col="col")

    expected = tuple(text.get_text() for text in grid.figlegend.texts)
    # collapse the per-axes label tuples: all facets must share one ordering
    labels_per_axes = {
        tuple(mappable.get_label() for mappable in axes_mappables)
        for axes_mappables in grid._mappables
    }
    assert list(labels_per_axes) == [expected]

def test_add_legend_by_default(self):
sc = self.ds.plot.scatter(x="A", y="B", hue="hue")
assert len(sc.figure.axes) == 2
Expand Down

0 comments on commit 747fe26

Please sign in to comment.