Skip to content

Commit

Permalink
Add legend_labels option to allow overriding legend labels (#5342)
Browse files Browse the repository at this point in the history
* Implement legend_labels in bokeh

* Implement legend_labels in mpl

* Fix mpl OverlayPlot legend_labels

* Add tests

* Fix flakes
  • Loading branch information
philippjfr authored Jun 24, 2022
1 parent 3d3154a commit d4a2de2
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 63 deletions.
36 changes: 0 additions & 36 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,39 +2083,3 @@ def _sort_by_distance(cls, raster, df, x, y):
Points: inspect_points,
Polygons: inspect_polygons
}


class categorical_legend(Operation):

def _process(self, element, key=None):
from ..plotting.util import rgb2hex
rasterize_op = element.pipeline.find(rasterize)
if isinstance(rasterize_op, datashade):
shade_op = rasterize_op
else:
shade_op = element.pipeline.find(shade)
if None in (shade_op, rasterize_op):
return None
hvds = element.dataset
input_el = element.pipeline.operations[0](hvds)
agg = rasterize_op._get_aggregator(input_el, rasterize_op.aggregator)
if not isinstance(agg, (ds.count_cat, ds.by)):
return
column = agg.column
if hasattr(hvds.data, 'dtypes'):
cats = list(hvds.data.dtypes[column].categories)
if cats == ['__UNKNOWN_CATEGORIES__']:
cats = list(hvds.data[column].cat.as_known().categories)
else:
cats = list(hvds.dimension_values(column, expanded=False))
colors = shade_op.color_key
color_data = [(0, 0, cat) for cat in cats]
if isinstance(colors, list):
cat_colors = {cat: colors[i] for i, cat in enumerate(cats)}
else:
cat_colors = {cat: colors[cat] for cat in cats}
cat_colors = {
cat: rgb2hex([v/256 for v in color[:3]]) if isinstance(color, tuple) else color
for cat, color in cat_colors.items()}
return Points(color_data, vdims=['category']).opts(
apply_ranges=False, cmap=cat_colors, color='category', show_legend=True)
61 changes: 37 additions & 24 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,19 +1134,25 @@ def _apply_transforms(self, element, data, ranges, style, group=None):
cmapper = self._get_colormapper(v, element, ranges,
dict(style), name=k+'_color_mapper',
group=group, **kwargs)
field = k
categorical = isinstance(cmapper, CategoricalColorMapper)
if categorical and val.dtype.kind in 'ifMub':
if v.dimension in element:
formatter = element.get_dimension(v.dimension).pprint_value
else:
formatter = str
field = k + '_str__'
data[k+'_str__'] = [formatter(d) for d in val]
else:
field = k
if categorical and getattr(self, 'show_legend', False):
legend_prop = 'legend_field' if bokeh_version >= LooseVersion('1.3.5') else 'legend'
new_style[legend_prop] = field

if categorical:
if val.dtype.kind in 'ifMub':
field = k + '_str__'
if v.dimension in element:
formatter = element.get_dimension(v.dimension).pprint_value
else:
formatter = str
data[field] = [formatter(d) for d in val]
if getattr(self, 'show_legend', False):
legend_labels = getattr(self, 'legend_labels', False)
if legend_labels:
label_field = f'_{field}_labels'
data[label_field] = [legend_labels.get(v, v) for v in val]
new_style['legend_field'] = label_field
else:
new_style['legend_field'] = field
key = {'field': field, 'transform': cmapper}
new_style[k] = key

Expand Down Expand Up @@ -2031,6 +2037,19 @@ def _init_glyph(self, plot, mapping, properties):

class LegendPlot(ElementPlot):

legend_cols = param.Integer(default=False, doc="""
Whether to lay out the legend as columns.""")

legend_labels = param.Dict(default=None, doc="""
Label overrides.""")

legend_muted = param.Boolean(default=False, doc="""
Controls whether the legend entries are muted by default.""")

legend_offset = param.NumericTuple(default=(0, 0), doc="""
If legend is placed outside the axis, this determines the
(width, height) offset in pixels from the original position.""")

legend_position = param.ObjectSelector(objects=["top_right",
"top_left",
"bottom_left",
Expand All @@ -2043,21 +2062,12 @@ class LegendPlot(ElementPlot):
options. The predefined options may be customized in the
legend_specs class attribute.""")

legend_muted = param.Boolean(default=False, doc="""
Controls whether the legend entries are muted by default.""")

legend_offset = param.NumericTuple(default=(0, 0), doc="""
If legend is placed outside the axis, this determines the
(width, height) offset in pixels from the original position.""")

legend_cols = param.Integer(default=False, doc="""
Whether to lay out the legend as columns.""")

legend_opts = param.Dict(default={}, doc="""
Allows setting specific styling options for the colorbar.""")

legend_specs = {'right': 'right', 'left': 'left', 'top': 'above',
'bottom': 'below'}
legend_specs = {
'right': 'right', 'left': 'left', 'top': 'above', 'bottom': 'below'
}

def _process_legend(self, plot=None):
plot = plot or self.handles['plot']
Expand Down Expand Up @@ -2208,6 +2218,9 @@ def _process_legend(self, overlay):
if (item in filtered or not item.renderers or
not any(r.visible or 'hv_legend' in r.tags for r in item.renderers)):
continue
if isinstance(item.label, dict) and 'value' in item.label and self.legend_labels:
label = item.label['value']
item.label = {'value': self.legend_labels.get(label, label)}
renderers += item.renderers
filtered.append(item)
legend.items[:] = list(util.unique_iterator(filtered))
Expand Down
2 changes: 1 addition & 1 deletion holoviews/plotting/bokeh/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _hover_opts(self, element):

def _init_glyphs(self, plot, element, ranges, source):
super(RGBPlot, self)._init_glyphs(plot, element, ranges, source)
if 'holoviews.operation.datashader' not in sys.modules or not self.show_legend:
if not ('holoviews.operation.datashader' in sys.modules and self.show_legend):
return
try:
legend = categorical_legend(element, backend=self.backend)
Expand Down
12 changes: 10 additions & 2 deletions holoviews/plotting/mpl/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,9 @@ def update_frame(self, key, ranges=None, element=None):
style = self.lookup_options(element, 'style')
self.style = style.max_cycles(max_cycles) if max_cycles else style

labels = getattr(self, 'legend_labels', {})
label = element.label if self.show_legend else ''
style = dict(label=label, zorder=self.zorder, **self.style[self.cyclic_index])
style = dict(label=labels.get(label, label), zorder=self.zorder, **self.style[self.cyclic_index])
axis_kwargs = self.update_handles(key, axis, element, ranges, style)
self._finalize_axis(key, element=element, ranges=ranges,
**(axis_kwargs if axis_kwargs else {}))
Expand Down Expand Up @@ -615,6 +616,8 @@ def _apply_transforms(self, element, ranges, style):
else:
factors = util.unique_array(val)
val = util.search_indices(val, factors)
labels = getattr(self, 'legend_labels', {})
factors = [labels.get(f, f) for f in factors]
new_style['cat_legend'] = {
'title': v.dimension, 'prop': 'c', 'factors': factors
}
Expand Down Expand Up @@ -1010,6 +1013,9 @@ class LegendPlot(ElementPlot):
legend_cols = param.Integer(default=None, doc="""
Number of legend columns in the legend.""")

legend_labels = param.Dict(default={}, doc="""
A mapping that allows overriding legend labels.""")

legend_position = param.ObjectSelector(objects=['inner', 'right',
'bottom', 'top',
'left', 'best',
Expand Down Expand Up @@ -1049,6 +1055,7 @@ def _legend_opts(self):
legend_opts.update(**dict(leg_spec, **self._fontsize('legend')))
return legend_opts


class OverlayPlot(LegendPlot, GenericOverlayPlot):
"""
OverlayPlot supports compositors processing of Overlays across maps.
Expand Down Expand Up @@ -1084,6 +1091,7 @@ def _adjust_legend(self, overlay, axis):
legend_plot = True
dimensions = overlay.kdims
title = ', '.join([d.label for d in dimensions])
labels = self.legend_labels
for key, subplot in self.subplots.items():
element = overlay.data.get(key, False)
if not subplot.show_legend or not element: continue
Expand All @@ -1101,7 +1109,7 @@ def _adjust_legend(self, overlay, axis):
if isinstance(subplot, OverlayPlot):
legend_data += subplot.handles.get('legend_data', {}).items()
elif element.label and handle:
legend_data.append((handle, element.label))
legend_data.append((handle, labels.get(element.label, element.label)))
all_handles, all_labels = list(zip(*legend_data)) if legend_data else ([], [])
data = OrderedDict()
used_labels = []
Expand Down
1 change: 1 addition & 0 deletions holoviews/plotting/mpl/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class RasterGridPlot(GridPlot, OverlayPlot):
invert_zaxis = param.Parameter(precedence=-1)
labelled = param.Parameter(precedence=-1)
legend_cols = param.Parameter(precedence=-1)
legend_labels = param.Parameter(precedence=-1)
legend_position = param.Parameter(precedence=-1)
legend_opts = param.Parameter(precedence=-1)
legend_limit = param.Parameter(precedence=-1)
Expand Down
4 changes: 4 additions & 0 deletions holoviews/plotting/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,10 @@ def _process(self, element, key=None):


class categorical_legend(Operation):
"""
Generates a Points element which contains information for generating
a legend by inspecting the pipeline of a datashaded RGB element.
"""

backend = param.String()

Expand Down
7 changes: 7 additions & 0 deletions holoviews/tests/plotting/bokeh/test_overlayplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,13 @@ def test_overlay_legend(self):
legend_labels = [l.label['value'] for l in plot.state.legend[0].items]
self.assertEqual(legend_labels, ['A', 'B'])

def test_overlay_legend_with_labels(self):
overlay = (Curve(range(10), label='A') * Curve(range(10), label='B')).opts(
legend_labels={'A': 'A Curve', 'B': 'B Curve'})
plot = bokeh_renderer.get_plot(overlay)
legend_labels = [l.label['value'] for l in plot.state.legend[0].items]
self.assertEqual(legend_labels, ['A Curve', 'B Curve'])

def test_dynamic_subplot_remapping(self):
# Checks that a plot is appropriately updated when reused
def cb(X):
Expand Down
18 changes: 18 additions & 0 deletions holoviews/tests/plotting/bokeh/test_pathplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,24 @@ def test_path_continuously_varying_color_legend(self):
self.assertEqual(item.label, legend)
self.assertEqual(item.renderers, [plot.handles['glyph_renderer']])

def test_path_continuously_varying_color_legend_with_labels(self):
data = {
"x": [1,2,3,4,5,6,7,8,9],
"y": [1,2,3,4,5,6,7,8,9],
"cat": [0,1,2,0,1,2,0,1,2]
}

colors = ["#FF0000", "#00FF00", "#0000FF"]
levels=[0,1,2,3]

path = Path(data, vdims="cat").opts(color="cat", cmap=dict(zip(levels, colors)), line_width=4, show_legend=True, legend_labels={0: 'A', 1: 'B', 2: 'C'})
plot = bokeh_renderer.get_plot(path)
cds = plot.handles['cds']
item = plot.state.legend[0].items[0]
legend = {'field': '_color_str___labels'}
self.assertEqual(cds.data['_color_str___labels'], ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B'])
self.assertEqual(item.label, legend)
self.assertEqual(item.renderers, [plot.handles['glyph_renderer']])


class TestPolygonPlot(TestBokehPlot):
Expand Down
10 changes: 10 additions & 0 deletions holoviews/tests/plotting/bokeh/test_pointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,16 @@ def test_point_categorical_color_op(self):
self.assertEqual(glyph.fill_color, {'field': 'color', 'transform': cmapper})
self.assertEqual(glyph.line_color, {'field': 'color', 'transform': cmapper})

def test_point_categorical_color_op_legend_with_labels(self):
labels = {'A': 'A point', 'B': 'B point', 'C': 'C point'}
points = Points([(0, 0, 'A'), (0, 1, 'B'), (0, 2, 'C')],
vdims='color').opts(color='color', show_legend=True, legend_labels=labels)
plot = bokeh_renderer.get_plot(points)
cds = plot.handles['cds']
legend = plot.state.legend[0].items[0]
assert legend.label == {'field': '_color_labels'}
assert cds.data['_color_labels'] == ['A point', 'B point', 'C point']

def test_point_categorical_dtype_color_op(self):
df = pd.DataFrame(dict(sample_id=['subject 1', 'subject 2', 'subject 3', 'subject 4'], category=['apple', 'pear', 'apple', 'pear'], value=[1, 2, 3, 4]))
df['category'] = df['category'].astype('category')
Expand Down
19 changes: 19 additions & 0 deletions holoviews/tests/plotting/matplotlib/test_overlayplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,22 @@ def test_overlay_ylabel_override(self):
overlay = (Curve(range(10)).options(ylabel='custom y-label') * Curve(range(10)))
axes = mpl_renderer.get_plot(overlay).handles['axis']
self.assertEqual(axes.get_ylabel(), 'custom y-label')



class TestLegends(TestMPLPlot):

def test_overlay_legend(self):
overlay = Curve(range(10), label='A') * Curve(range(10), label='B')
plot = mpl_renderer.get_plot(overlay)
legend = plot.handles['legend']
legend_labels = [l.get_text() for l in legend.texts]
self.assertEqual(legend_labels, ['A', 'B'])

def test_overlay_legend_with_labels(self):
overlay = (Curve(range(10), label='A') * Curve(range(10), label='B')).opts(
legend_labels={'A': 'A Curve', 'B': 'B Curve'})
plot = mpl_renderer.get_plot(overlay)
legend = plot.handles['legend']
legend_labels = [l.get_text() for l in legend.texts]
self.assertEqual(legend_labels, ['A Curve', 'B Curve'])
16 changes: 16 additions & 0 deletions holoviews/tests/plotting/matplotlib/test_pointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,22 @@ def test_point_categorical_color_op(self):
self.assertEqual(np.asarray(artist.get_array()), np.array([0, 1, 0]))
self.assertEqual(artist.get_clim(), (0, 1))

def test_point_categorical_color_op_legend(self):
points = Points([(0, 0, 'A'), (0, 1, 'B'), (0, 2, 'A')],
vdims='color').options(color='color', show_legend=True)
plot = mpl_renderer.get_plot(points)
leg = plot.handles['axis'].get_legend()
legend_labels = [l.get_text() for l in leg.texts]
self.assertEqual(legend_labels, ['A', 'B'])

def test_point_categorical_color_op_legend_with_labels(self):
points = Points([(0, 0, 'A'), (0, 1, 'B'), (0, 2, 'A')], vdims='color').opts(
color='color', show_legend=True, legend_labels={'A': 'A point', 'B': 'B point'})
plot = mpl_renderer.get_plot(points)
leg = plot.handles['axis'].get_legend()
legend_labels = [l.get_text() for l in leg.texts]
self.assertEqual(legend_labels, ['A point', 'B point'])

def test_point_size_op(self):
points = Points([(0, 0, 1), (0, 1, 4), (0, 2, 8)],
vdims='size').options(s='size')
Expand Down

0 comments on commit d4a2de2

Please sign in to comment.