Skip to content

Commit

Permalink
add legend as dict for line plots (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
gidden authored and danielhuppmann committed Jul 17, 2018
1 parent bae49ad commit 0877452
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 16 deletions.
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Next Release

- (#73)[https://github.com/IAMconsortium/pyam/pull/73] Adds ability to remove labels for markers, colors, or linestyles
- (#71)[https://github.com/IAMconsortium/pyam/pull/71] Line plots `legend` keyword can now be a dictionary of legend arguments
- (#70)[https://github.com/IAMconsortium/pyam/pull/70] Support reading of both SSP and RCP data files downloaded from the IIASA database.
- (#66)[https://github.com/IAMconsortium/pyam/pull/66] Fixes a bug in the `interpolate()` function (duplication of data points if already defined)
- (#65)[https://github.com/IAMconsortium/pyam/pull/65] Add a `filter_by_meta()` function to filter/join a pd.DataFrame with an IamDataFrame.meta table
43 changes: 27 additions & 16 deletions pyam/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
except ImportError:
from functools32 import lru_cache

from pyam.logger import logger
from pyam.run_control import run_control
from pyam.utils import requires_package, SORT_IDX, isstr

Expand All @@ -35,6 +36,9 @@
# explicitly declared
_DEFAULT_PROPS = None

# maximum number of labels after which do not show legends by default
MAX_LEGEND_LABELS = 13


def reset_default_props(**kwargs):
"""Reset properties to initial cycle point"""
Expand Down Expand Up @@ -205,13 +209,13 @@ def region_plot(df, column='value', ax=None, crs=None, gdf=None, add_features=Tr
)
cb = plt.colorbar(scalar_map, **cbar)

if legend:
if legend is not False:
if legend is True: # use some defaults
legend = dict(
bbox_to_anchor=(1.32, 0.5) if cbar else (1.2, 0.5),
loc='right',
)
ax.legend(handles, labels, **legend)
_add_legend(ax, handles, labels, legend)

if title:
var = df['variable'].unique()[0]
Expand Down Expand Up @@ -471,9 +475,10 @@ def line_plot(df, x='year', y='value', ax=None, legend=None, title=True,
The column to use for y-axis values
default: value
ax : matplotlib.Axes, optional
legend : bool, optional
Include a legend (`None` displays legend only if less than 13 entries)
default: None
legend : bool or dictionary, optional
Add a legend. If a dictionary is provided, it will be used as keyword
arguments in creating the legend.
default: None (displays legend only if less than 13 entries)
title : bool or string, optional
Display a default or custom title.
color : string, optional
Expand All @@ -496,7 +501,6 @@ def line_plot(df, x='year', y='value', ax=None, legend=None, title=True,
default: []
kwargs : Additional arguments to pass to the pd.DataFrame.plot() function
"""

if ax is None:
fig, ax = plt.subplots()

Expand Down Expand Up @@ -524,11 +528,11 @@ def line_plot(df, x='year', y='value', ax=None, legend=None, title=True,
prop_idx[kind] = df.columns.names.index(var)

# plot data, keeping track of which legend labels to apply
legend_data = []
no_label = [rm_legend_label] if isstr(rm_legend_label) else rm_legend_label
for col, data in df.iteritems():
pargs = {}
labels = []
# build plotting args and line legend labels
for key, kind, var in [('c', 'color', color),
('marker', 'marker', marker),
('linestyle', 'linestyle', linestyle)]:
Expand All @@ -539,19 +543,18 @@ def line_plot(df, x='year', y='value', ax=None, legend=None, title=True,
labels.append(repr(label).lstrip("u'").strip("'"))
else:
pargs[key] = var

legend_data.append(' '.join(labels))
kwargs.update(pargs)
data.plot(ax=ax, **kwargs)
if labels:
ax.lines[-1].set_label(' '.join(labels))

# build legend handles and labels
# build unique legend handles and labels
handles, labels = ax.get_legend_handles_labels()
if legend_data != [''] * len(legend_data):
labels = sorted(list(set(tuple(legend_data))))
idxs = [legend_data.index(d) for d in labels]
handles = [handles[i] for i in idxs]
if legend is None and len(labels) < 13 or legend is True:
ax.legend(handles, labels)
handles, labels = np.array(handles), np.array(labels)
_, idx = np.unique(labels, return_index=True)
handles, labels = handles[idx], labels[idx]
if legend is not False:
_add_legend(ax, handles, labels, legend)

# add default labels if possible
ax.set_xlabel(x.title())
Expand All @@ -571,3 +574,11 @@ def line_plot(df, x='year', y='value', ax=None, legend=None, title=True,
ax.set_title(' '.join(_title))

return ax, handles, labels


def _add_legend(ax, handles, labels, legend):
if legend is None and len(labels) >= MAX_LEGEND_LABELS:
logger().info('>={} labels, not applying legend'.format(
MAX_LEGEND_LABELS))
legend = {} if legend in [True, None] else legend
ax.legend(handles, labels, **legend)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def test_line_plot(plot_df):
return fig


@pytest.mark.skipif(IS_WINDOWS, reason=WINDOWS_REASON)
@pytest.mark.mpl_image_compare(**MPL_KWARGS)
def test_line_plot_dict_legend(plot_df):
fig, ax = plt.subplots(figsize=(8, 8))
plot_df.line_plot(ax=ax, legend=dict(
loc='center left', bbox_to_anchor=(1.0, 0.5)))
return fig


@pytest.mark.skipif(IS_WINDOWS, reason=WINDOWS_REASON)
@pytest.mark.mpl_image_compare(**MPL_KWARGS)
def test_line_no_legend(plot_df):
Expand Down

0 comments on commit 0877452

Please sign in to comment.