Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add legend_labels option to allow overriding legend labels #5342

Merged
merged 5 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,39 +2083,3 @@ def _sort_by_distance(cls, raster, df, x, y):
Points: inspect_points,
Polygons: inspect_polygons
}


class categorical_legend(Operation):

def _process(self, element, key=None):
from ..plotting.util import rgb2hex
rasterize_op = element.pipeline.find(rasterize)
if isinstance(rasterize_op, datashade):
shade_op = rasterize_op
else:
shade_op = element.pipeline.find(shade)
if None in (shade_op, rasterize_op):
return None
hvds = element.dataset
input_el = element.pipeline.operations[0](hvds)
agg = rasterize_op._get_aggregator(input_el, rasterize_op.aggregator)
if not isinstance(agg, (ds.count_cat, ds.by)):
return
column = agg.column
if hasattr(hvds.data, 'dtypes'):
cats = list(hvds.data.dtypes[column].categories)
if cats == ['__UNKNOWN_CATEGORIES__']:
cats = list(hvds.data[column].cat.as_known().categories)
else:
cats = list(hvds.dimension_values(column, expanded=False))
colors = shade_op.color_key
color_data = [(0, 0, cat) for cat in cats]
if isinstance(colors, list):
cat_colors = {cat: colors[i] for i, cat in enumerate(cats)}
else:
cat_colors = {cat: colors[cat] for cat in cats}
cat_colors = {
cat: rgb2hex([v/256 for v in color[:3]]) if isinstance(color, tuple) else color
for cat, color in cat_colors.items()}
return Points(color_data, vdims=['category']).opts(
apply_ranges=False, cmap=cat_colors, color='category', show_legend=True)
61 changes: 37 additions & 24 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,19 +1134,25 @@ def _apply_transforms(self, element, data, ranges, style, group=None):
cmapper = self._get_colormapper(v, element, ranges,
dict(style), name=k+'_color_mapper',
group=group, **kwargs)
field = k
categorical = isinstance(cmapper, CategoricalColorMapper)
if categorical and val.dtype.kind in 'ifMub':
if v.dimension in element:
formatter = element.get_dimension(v.dimension).pprint_value
else:
formatter = str
field = k + '_str__'
data[k+'_str__'] = [formatter(d) for d in val]
else:
field = k
if categorical and getattr(self, 'show_legend', False):
legend_prop = 'legend_field' if bokeh_version >= LooseVersion('1.3.5') else 'legend'
new_style[legend_prop] = field

if categorical:
if val.dtype.kind in 'ifMub':
field = k + '_str__'
if v.dimension in element:
formatter = element.get_dimension(v.dimension).pprint_value
else:
formatter = str
data[field] = [formatter(d) for d in val]
if getattr(self, 'show_legend', False):
legend_labels = getattr(self, 'legend_labels', False)
if legend_labels:
label_field = f'_{field}_labels'
data[label_field] = [legend_labels.get(v, v) for v in val]
new_style['legend_field'] = label_field
else:
new_style['legend_field'] = field
key = {'field': field, 'transform': cmapper}
new_style[k] = key

Expand Down Expand Up @@ -2031,6 +2037,19 @@ def _init_glyph(self, plot, mapping, properties):

class LegendPlot(ElementPlot):

legend_cols = param.Integer(default=False, doc="""
Whether to lay out the legend as columns.""")

legend_labels = param.Dict(default=None, doc="""
Label overrides.""")

legend_muted = param.Boolean(default=False, doc="""
Controls whether the legend entries are muted by default.""")

legend_offset = param.NumericTuple(default=(0, 0), doc="""
If legend is placed outside the axis, this determines the
(width, height) offset in pixels from the original position.""")

legend_position = param.ObjectSelector(objects=["top_right",
"top_left",
"bottom_left",
Expand All @@ -2043,21 +2062,12 @@ class LegendPlot(ElementPlot):
options. The predefined options may be customized in the
legend_specs class attribute.""")

legend_muted = param.Boolean(default=False, doc="""
Controls whether the legend entries are muted by default.""")

legend_offset = param.NumericTuple(default=(0, 0), doc="""
If legend is placed outside the axis, this determines the
(width, height) offset in pixels from the original position.""")

legend_cols = param.Integer(default=False, doc="""
Whether to lay out the legend as columns.""")

legend_opts = param.Dict(default={}, doc="""
Allows setting specific styling options for the colorbar.""")

legend_specs = {'right': 'right', 'left': 'left', 'top': 'above',
'bottom': 'below'}
legend_specs = {
'right': 'right', 'left': 'left', 'top': 'above', 'bottom': 'below'
}

def _process_legend(self, plot=None):
plot = plot or self.handles['plot']
Expand Down Expand Up @@ -2208,6 +2218,9 @@ def _process_legend(self, overlay):
if (item in filtered or not item.renderers or
not any(r.visible or 'hv_legend' in r.tags for r in item.renderers)):
continue
if isinstance(item.label, dict) and 'value' in item.label and self.legend_labels:
label = item.label['value']
item.label = {'value': self.legend_labels.get(label, label)}
renderers += item.renderers
filtered.append(item)
legend.items[:] = list(util.unique_iterator(filtered))
Expand Down
2 changes: 1 addition & 1 deletion holoviews/plotting/bokeh/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _hover_opts(self, element):

def _init_glyphs(self, plot, element, ranges, source):
super(RGBPlot, self)._init_glyphs(plot, element, ranges, source)
if 'holoviews.operation.datashader' not in sys.modules or not self.show_legend:
if not ('holoviews.operation.datashader' in sys.modules and self.show_legend):
return
try:
legend = categorical_legend(element, backend=self.backend)
Expand Down
12 changes: 10 additions & 2 deletions holoviews/plotting/mpl/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,9 @@ def update_frame(self, key, ranges=None, element=None):
style = self.lookup_options(element, 'style')
self.style = style.max_cycles(max_cycles) if max_cycles else style

labels = getattr(self, 'legend_labels', {})
label = element.label if self.show_legend else ''
style = dict(label=label, zorder=self.zorder, **self.style[self.cyclic_index])
style = dict(label=labels.get(label, label), zorder=self.zorder, **self.style[self.cyclic_index])
axis_kwargs = self.update_handles(key, axis, element, ranges, style)
self._finalize_axis(key, element=element, ranges=ranges,
**(axis_kwargs if axis_kwargs else {}))
Expand Down Expand Up @@ -615,6 +616,8 @@ def _apply_transforms(self, element, ranges, style):
else:
factors = util.unique_array(val)
val = util.search_indices(val, factors)
labels = getattr(self, 'legend_labels', {})
factors = [labels.get(f, f) for f in factors]
new_style['cat_legend'] = {
'title': v.dimension, 'prop': 'c', 'factors': factors
}
Expand Down Expand Up @@ -1010,6 +1013,9 @@ class LegendPlot(ElementPlot):
legend_cols = param.Integer(default=None, doc="""
Number of legend columns in the legend.""")

legend_labels = param.Dict(default={}, doc="""
A mapping that allows overriding legend labels.""")

legend_position = param.ObjectSelector(objects=['inner', 'right',
'bottom', 'top',
'left', 'best',
Expand Down Expand Up @@ -1049,6 +1055,7 @@ def _legend_opts(self):
legend_opts.update(**dict(leg_spec, **self._fontsize('legend')))
return legend_opts


class OverlayPlot(LegendPlot, GenericOverlayPlot):
"""
OverlayPlot supports compositors processing of Overlays across maps.
Expand Down Expand Up @@ -1084,6 +1091,7 @@ def _adjust_legend(self, overlay, axis):
legend_plot = True
dimensions = overlay.kdims
title = ', '.join([d.label for d in dimensions])
labels = self.legend_labels
for key, subplot in self.subplots.items():
element = overlay.data.get(key, False)
if not subplot.show_legend or not element: continue
Expand All @@ -1101,7 +1109,7 @@ def _adjust_legend(self, overlay, axis):
if isinstance(subplot, OverlayPlot):
legend_data += subplot.handles.get('legend_data', {}).items()
elif element.label and handle:
legend_data.append((handle, element.label))
legend_data.append((handle, labels.get(element.label, element.label)))
all_handles, all_labels = list(zip(*legend_data)) if legend_data else ([], [])
data = OrderedDict()
used_labels = []
Expand Down
1 change: 1 addition & 0 deletions holoviews/plotting/mpl/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class RasterGridPlot(GridPlot, OverlayPlot):
invert_zaxis = param.Parameter(precedence=-1)
labelled = param.Parameter(precedence=-1)
legend_cols = param.Parameter(precedence=-1)
legend_labels = param.Parameter(precedence=-1)
legend_position = param.Parameter(precedence=-1)
legend_opts = param.Parameter(precedence=-1)
legend_limit = param.Parameter(precedence=-1)
Expand Down
4 changes: 4 additions & 0 deletions holoviews/plotting/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,10 @@ def _process(self, element, key=None):


class categorical_legend(Operation):
"""
Generates a Points element which contains information for generating
a legend by inspecting the pipeline of a datashaded RGB element.
"""

backend = param.String()

Expand Down
7 changes: 7 additions & 0 deletions holoviews/tests/plotting/bokeh/test_overlayplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,13 @@ def test_overlay_legend(self):
legend_labels = [l.label['value'] for l in plot.state.legend[0].items]
self.assertEqual(legend_labels, ['A', 'B'])

def test_overlay_legend_with_labels(self):
overlay = (Curve(range(10), label='A') * Curve(range(10), label='B')).opts(
legend_labels={'A': 'A Curve', 'B': 'B Curve'})
plot = bokeh_renderer.get_plot(overlay)
legend_labels = [l.label['value'] for l in plot.state.legend[0].items]
self.assertEqual(legend_labels, ['A Curve', 'B Curve'])

def test_dynamic_subplot_remapping(self):
# Checks that a plot is appropriately updated when reused
def cb(X):
Expand Down
18 changes: 18 additions & 0 deletions holoviews/tests/plotting/bokeh/test_pathplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,24 @@ def test_path_continuously_varying_color_legend(self):
self.assertEqual(item.label, legend)
self.assertEqual(item.renderers, [plot.handles['glyph_renderer']])

def test_path_continuously_varying_color_legend_with_labels(self):
data = {
"x": [1,2,3,4,5,6,7,8,9],
"y": [1,2,3,4,5,6,7,8,9],
"cat": [0,1,2,0,1,2,0,1,2]
}

colors = ["#FF0000", "#00FF00", "#0000FF"]
levels=[0,1,2,3]

path = Path(data, vdims="cat").opts(color="cat", cmap=dict(zip(levels, colors)), line_width=4, show_legend=True, legend_labels={0: 'A', 1: 'B', 2: 'C'})
plot = bokeh_renderer.get_plot(path)
cds = plot.handles['cds']
item = plot.state.legend[0].items[0]
legend = {'field': '_color_str___labels'}
self.assertEqual(cds.data['_color_str___labels'], ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B'])
self.assertEqual(item.label, legend)
self.assertEqual(item.renderers, [plot.handles['glyph_renderer']])


class TestPolygonPlot(TestBokehPlot):
Expand Down
10 changes: 10 additions & 0 deletions holoviews/tests/plotting/bokeh/test_pointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,16 @@ def test_point_categorical_color_op(self):
self.assertEqual(glyph.fill_color, {'field': 'color', 'transform': cmapper})
self.assertEqual(glyph.line_color, {'field': 'color', 'transform': cmapper})

def test_point_categorical_color_op_legend_with_labels(self):
labels = {'A': 'A point', 'B': 'B point', 'C': 'C point'}
points = Points([(0, 0, 'A'), (0, 1, 'B'), (0, 2, 'C')],
vdims='color').opts(color='color', show_legend=True, legend_labels=labels)
plot = bokeh_renderer.get_plot(points)
cds = plot.handles['cds']
legend = plot.state.legend[0].items[0]
assert legend.label == {'field': '_color_labels'}
assert cds.data['_color_labels'] == ['A point', 'B point', 'C point']

def test_point_categorical_dtype_color_op(self):
df = pd.DataFrame(dict(sample_id=['subject 1', 'subject 2', 'subject 3', 'subject 4'], category=['apple', 'pear', 'apple', 'pear'], value=[1, 2, 3, 4]))
df['category'] = df['category'].astype('category')
Expand Down
19 changes: 19 additions & 0 deletions holoviews/tests/plotting/matplotlib/test_overlayplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,22 @@ def test_overlay_ylabel_override(self):
overlay = (Curve(range(10)).options(ylabel='custom y-label') * Curve(range(10)))
axes = mpl_renderer.get_plot(overlay).handles['axis']
self.assertEqual(axes.get_ylabel(), 'custom y-label')



class TestLegends(TestMPLPlot):

def test_overlay_legend(self):
overlay = Curve(range(10), label='A') * Curve(range(10), label='B')
plot = mpl_renderer.get_plot(overlay)
legend = plot.handles['legend']
legend_labels = [l.get_text() for l in legend.texts]
self.assertEqual(legend_labels, ['A', 'B'])

def test_overlay_legend_with_labels(self):
overlay = (Curve(range(10), label='A') * Curve(range(10), label='B')).opts(
legend_labels={'A': 'A Curve', 'B': 'B Curve'})
plot = mpl_renderer.get_plot(overlay)
legend = plot.handles['legend']
legend_labels = [l.get_text() for l in legend.texts]
self.assertEqual(legend_labels, ['A Curve', 'B Curve'])
16 changes: 16 additions & 0 deletions holoviews/tests/plotting/matplotlib/test_pointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,22 @@ def test_point_categorical_color_op(self):
self.assertEqual(np.asarray(artist.get_array()), np.array([0, 1, 0]))
self.assertEqual(artist.get_clim(), (0, 1))

def test_point_categorical_color_op_legend(self):
points = Points([(0, 0, 'A'), (0, 1, 'B'), (0, 2, 'A')],
vdims='color').options(color='color', show_legend=True)
plot = mpl_renderer.get_plot(points)
leg = plot.handles['axis'].get_legend()
legend_labels = [l.get_text() for l in leg.texts]
self.assertEqual(legend_labels, ['A', 'B'])

def test_point_categorical_color_op_legend_with_labels(self):
points = Points([(0, 0, 'A'), (0, 1, 'B'), (0, 2, 'A')], vdims='color').opts(
color='color', show_legend=True, legend_labels={'A': 'A point', 'B': 'B point'})
plot = mpl_renderer.get_plot(points)
leg = plot.handles['axis'].get_legend()
legend_labels = [l.get_text() for l in leg.texts]
self.assertEqual(legend_labels, ['A point', 'B point'])

def test_point_size_op(self):
points = Points([(0, 0, 1), (0, 1, 4), (0, 2, 8)],
vdims='size').options(s='size')
Expand Down