From 9c578a27730dd5fe73a907b6dc24b02f442bc88d Mon Sep 17 00:00:00 2001
From: Philipp Rudiger
Date: Mon, 9 Oct 2017 01:15:25 +0100
Subject: [PATCH] Refactored bokeh plotting API
---
holoviews/plotting/bokeh/annotation.py | 34 +++++-----
holoviews/plotting/bokeh/chart.py | 88 +++++++++++---------------
holoviews/plotting/bokeh/element.py | 36 +++++------
holoviews/plotting/bokeh/graphs.py | 10 +--
holoviews/plotting/bokeh/path.py | 26 ++++----
holoviews/plotting/bokeh/plot.py | 2 +-
holoviews/plotting/bokeh/raster.py | 48 ++++++--------
holoviews/plotting/bokeh/tabular.py | 15 +++--
8 files changed, 113 insertions(+), 146 deletions(-)
diff --git a/holoviews/plotting/bokeh/annotation.py b/holoviews/plotting/bokeh/annotation.py
index f738b3fb26..22a0d9ab11 100644
--- a/holoviews/plotting/bokeh/annotation.py
+++ b/holoviews/plotting/bokeh/annotation.py
@@ -22,33 +22,29 @@ class TextPlot(ElementPlot):
style_opts = text_properties+['color']
_plot_methods = dict(single='text', batched='text')
- def _glyph_properties(self, plot, element, source, ranges):
- props = super(TextPlot, self)._glyph_properties(plot, element, source, ranges)
- props['text_align'] = element.halign
- props['text_baseline'] = 'middle' if element.valign == 'center' else element.valign
- if 'color' in props:
- props['text_color'] = props.pop('color')
- return props
-
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
mapping = dict(x='x', y='y', text='text')
if self.static_source:
- return dict(x=[], y=[], text=[]), mapping
+ return dict(x=[], y=[], text=[]), mapping, style
if self.invert_axes:
data = dict(x=[element.y], y=[element.x])
else:
data = dict(x=[element.x], y=[element.y])
self._categorize_data(data, ('x', 'y'), element.dimensions())
data['text'] = [element.text]
- return (data, mapping)
+ style['text_align'] = element.halign
+ style['text_baseline'] = 'middle' if element.valign == 'center' else element.valign
+ if 'color' in style:
+ style['text_color'] = style.pop('color')
+ return (data, mapping, style)
def get_batched_data(self, element, ranges=None):
data = defaultdict(list)
for key, el in element.data.items():
- eldata, elmapping = self.get_data(el, ranges)
+ eldata, elmapping, style = self.get_data(el, ranges)
for k, eld in eldata.items():
data[k].extend(eld)
- return data, elmapping
+ return data, elmapping, style
def get_extents(self, element, ranges=None):
return None, None, None, None
@@ -63,7 +59,7 @@ class LineAnnotationPlot(ElementPlot):
_plot_methods = dict(single='Span')
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
data, mapping = {}, {}
dim = 'width' if isinstance(element, HLine) else 'height'
if self.invert_axes:
@@ -73,7 +69,7 @@ def get_data(self, element, ranges=None):
if isinstance(loc, datetime_types):
loc = date_to_integer(loc)
mapping['location'] = loc
- return (data, mapping)
+ return (data, mapping, style)
def _init_glyph(self, plot, mapping, properties):
"""
@@ -97,7 +93,7 @@ class SplinePlot(ElementPlot):
style_opts = line_properties
_plot_methods = dict(single='bezier')
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
if self.invert_axes:
data_attrs = ['y0', 'x0', 'cy0', 'cx0', 'cy1', 'cx1', 'y1', 'x1']
else:
@@ -117,7 +113,7 @@ def get_data(self, element, ranges=None):
self.warning('Bokeh SplitPlot only support cubic splines, '
'unsupported splines were skipped during plotting.')
data = {da: data[da] for da in data_attrs}
- return (data, dict(zip(data_attrs, data_attrs)))
+ return (data, dict(zip(data_attrs, data_attrs)), style)
@@ -131,7 +127,7 @@ class ArrowPlot(CompositeElementPlot):
_plot_methods = dict(single='text')
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
plot = self.state
label_mapping = dict(x='x', y='y', text='text')
@@ -167,7 +163,7 @@ def get_data(self, element, ranges=None):
label_data = dict(x=[x2], y=[y2])
label_data['text'] = [element.text]
return ({'label': label_data},
- {'arrow': arrow_opts, 'label': label_mapping})
+ {'arrow': arrow_opts, 'label': label_mapping}, style)
def _init_glyph(self, plot, mapping, properties, key):
"""
diff --git a/holoviews/plotting/bokeh/chart.py b/holoviews/plotting/bokeh/chart.py
index 02855aa5f5..779bbe273a 100644
--- a/holoviews/plotting/bokeh/chart.py
+++ b/holoviews/plotting/bokeh/chart.py
@@ -71,8 +71,7 @@ def _get_size_data(self, element, ranges, style):
return data, mapping
- def get_data(self, element, ranges=None):
- style = self.style[self.cyclic_index]
+ def get_data(self, element, ranges, style):
dims = element.dimensions(label=True)
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)
@@ -94,17 +93,17 @@ def get_data(self, element, ranges=None):
mapping.update(smapping)
self._get_hover_data(data, element)
- return data, mapping
+ return data, mapping, style
- def get_batched_data(self, element, ranges=None):
+ def get_batched_data(self, element, ranges):
data = defaultdict(list)
zorders = self._updated_zorders(element)
- styles = self.lookup_options(element.last, 'style')
- styles = styles.max_cycles(len(self.ordering))
for (key, el), zorder in zip(element.data.items(), zorders):
self.set_param(**self.lookup_options(el, 'plot').options)
- eldata, elmapping = self.get_data(el, ranges)
+ style = self.lookup_options(element.last, 'style')
+ style = style.max_cycles(len(self.ordering))[zorder]
+ eldata, elmapping, style = self.get_data(el, ranges, style)
for k, eld in eldata.items():
data[k].append(eld)
@@ -114,7 +113,6 @@ def get_batched_data(self, element, ranges=None):
# Apply static styles
nvals = len(list(eldata.values())[0])
- style = styles[zorder]
sdata, smapping = expand_batched_style(style, self._batched_style_opts,
elmapping, nvals)
elmapping.update(smapping)
@@ -127,7 +125,7 @@ def get_batched_data(self, element, ranges=None):
data[sanitized].append([k]*nvals)
data = {k: np.concatenate(v) for k, v in data.items()}
- return data, elmapping
+ return data, elmapping, style
@@ -185,8 +183,7 @@ def _glyph_properties(self, *args):
return properties
- def get_data(self, element, ranges=None):
- style = self.style[self.cyclic_index]
+ def get_data(self, element, ranges, style):
input_scale = style.pop('scale', 1.0)
# Get x, y, angle, magnitude and color data
@@ -242,7 +239,7 @@ def get_data(self, element, ranges=None):
data[cdim.name] = color
mapping.update(cmapping)
- return (data, mapping)
+ return (data, mapping, style)
@@ -259,12 +256,12 @@ class CurvePlot(ElementPlot):
_plot_methods = dict(single='line', batched='multi_line')
_batched_style_opts = line_properties
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)
x = element.get_dimension(xidx).name
y = element.get_dimension(yidx).name
if self.static_source:
- return {}, dict(x=x, y=y)
+ return {}, dict(x=x, y=y), style
if 'steps' in self.interpolation:
element = interpolate_curve(element, interpolation=self.interpolation)
@@ -272,7 +269,7 @@ def get_data(self, element, ranges=None):
y: element.dimension_values(yidx)}
self._get_hover_data(data, element)
self._categorize_data(data, (x, y), element.dimensions())
- return (data, dict(x=x, y=y))
+ return (data, dict(x=x, y=y), style)
def _hover_opts(self, element):
if self.batched:
@@ -283,16 +280,15 @@ def _hover_opts(self, element):
line_policy = 'nearest'
return dims, dict(line_policy=line_policy)
- def get_batched_data(self, overlay, ranges=None):
+ def get_batched_data(self, overlay, ranges):
data = defaultdict(list)
zorders = self._updated_zorders(overlay)
- styles = self.lookup_options(overlay.last, 'style')
- styles = styles.max_cycles(len(self.ordering))
-
for (key, el), zorder in zip(overlay.data.items(), zorders):
self.set_param(**self.lookup_options(el, 'plot').options)
- eldata, elmapping = self.get_data(el, ranges)
+ style = self.lookup_options(el, 'style')
+ style = style.max_cycles(len(self.ordering))[zorder]
+ eldata, elmapping, style = self.get_data(el, ranges, style)
# Skip if data empty
if not eldata:
@@ -302,7 +298,6 @@ def get_batched_data(self, overlay, ranges=None):
data[k].append(eld)
# Apply static styles
- style = styles[zorder]
sdata, smapping = expand_batched_style(style, self._batched_style_opts,
elmapping, nvals=1)
elmapping.update(smapping)
@@ -316,7 +311,7 @@ def get_batched_data(self, overlay, ranges=None):
if not any(v is None for v in vals)}
mapping = {{'x': 'xs', 'y': 'ys'}.get(k, k): v
for k, v in elmapping.items()}
- return data, mapping
+ return data, mapping, style
@@ -325,7 +320,7 @@ class HistogramPlot(ElementPlot):
style_opts = line_properties + fill_properties
_plot_methods = dict(single='quad')
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
if self.invert_axes:
mapping = dict(top='left', bottom='right', left=0, right='top')
else:
@@ -336,7 +331,7 @@ def get_data(self, element, ranges=None):
data = dict(top=element.values, left=element.edges[:-1],
right=element.edges[1:])
self._get_hover_data(data, element)
- return (data, mapping)
+ return (data, mapping, style)
def get_extents(self, element, ranges):
x0, y0, x1, y1 = super(HistogramPlot, self).get_extents(element, ranges)
@@ -369,7 +364,7 @@ class SideHistogramPlot(ColorbarPlot, HistogramPlot):
main_source.trigger('change')
"""
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
if self.invert_axes:
mapping = dict(top='right', bottom='left', left=0, right='top')
else:
@@ -389,7 +384,7 @@ def get_data(self, element, ranges=None):
mapping['fill_color'] = {'field': dim.name,
'transform': cmapper}
self._get_hover_data(data, element)
- return (data, mapping)
+ return (data, mapping, style)
def _init_glyph(self, plot, mapping, properties):
@@ -428,10 +423,10 @@ class ErrorPlot(ElementPlot):
_plot_methods = dict(single=Whisker)
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
mapping = dict(self._mapping)
if self.static_source:
- return {}, mapping
+ return {}, mapping, style
base = element.dimension_values(0)
ys = element.dimension_values(1)
@@ -448,7 +443,7 @@ def get_data(self, element, ranges=None):
else:
mapping['dimension'] = 'height'
self._categorize_data(data, ('base',), element.dimensions())
- return (data, mapping)
+ return (data, mapping, style)
def _init_glyph(self, plot, mapping, properties):
@@ -493,10 +488,10 @@ def get_extents(self, element, ranges):
ranges[vdim] = (np.nanmin([0, ranges[vdim][0]]), ranges[vdim][1])
return super(AreaPlot, self).get_extents(element, ranges)
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
mapping = dict(self._mapping)
if self.static_source:
- return {}, mapping
+ return {}, mapping, style
xs = element.dimension_values(0)
if len(element.vdims) > 1:
@@ -510,7 +505,7 @@ def get_data(self, element, ranges=None):
mapping['dimension'] = 'width'
else:
mapping['dimension'] = 'height'
- return data, mapping
+ return data, mapping, style
@@ -554,8 +549,7 @@ def get_extents(self, element, ranges):
t = np.nanmax([0, t])
return l, b, r, t
- def get_data(self, element, ranges=None):
- style = self.style[self.cyclic_index]
+ def get_data(self, element, ranges, style):
dims = element.dimensions(label=True)
pos = self.position
@@ -583,7 +577,7 @@ def get_data(self, element, ranges=None):
for d in dims:
data[dimension_sanitizer(d)] = element.dimension_values(d)
- return data, mapping
+ return data, mapping, style
class SideSpikesPlot(SpikesPlot):
@@ -639,7 +633,7 @@ class BarPlot(ColorbarPlot, LegendPlot):
style_opts = line_properties + fill_properties + ['width', 'cmap']
- _plot_methods = dict(single=('vbar', 'hbar'), batched=('vbar', 'hbar'))
+ _plot_methods = dict(single=('vbar', 'hbar'))
# Declare that y-range should auto-range if not bounded
_y_range_type = DataRange1d
@@ -769,7 +763,7 @@ def _glyph_properties(self, *args):
del props['width']
return props
- def get_data(self, element, ranges):
+ def get_data(self, element, ranges, style):
# Get x, y, group, stack and color dimensions
grouping = None
group_dim = element.get_dimension(self.group_index)
@@ -793,7 +787,6 @@ def get_data(self, element, ranges):
self.color_index = color_dim.name
# Define style information
- style = self.style[self.cyclic_index]
width = style.get('width', 1)
cmap = style.get('cmap')
hover = any(t == 'hover' or isinstance(t, HoverTool)
@@ -946,13 +939,7 @@ def get_data(self, element, ranges):
mapping.update({'y': mapping.pop('x'), 'left': mapping.pop('bottom'),
'right': mapping.pop('top'), 'height': mapping.pop('width')})
- return sanitized_data, mapping
-
- def get_batched_data(self, element, ranges):
- el = element.last
- collapsed = Bars(element.table(), kdims=el.kdims+element.kdims,
- vdims=el.vdims)
- return self.get_data(collapsed, ranges)
+ return sanitized_data, mapping, style
@@ -1001,8 +988,8 @@ def _get_axis_labels(self, *args, **kwargs):
ylabel = element.vdims[0].pprint_label
return xlabel, ylabel, None
- def _glyph_properties(self, plot, element, source, ranges):
- properties = dict(self.style[self.cyclic_index], source=source)
+ def _glyph_properties(self, plot, element, source, ranges, style):
+ properties = dict(style, source=source)
if self.show_legend and not element.kdims:
properties['legend'] = element.label
return properties
@@ -1025,12 +1012,11 @@ def _get_factors(self, element):
xfactors, yfactors = factors, []
return (yfactors, xfactors) if self.invert_axes else (xfactors, yfactors)
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
if element.kdims:
groups = element.groupby(element.kdims).data
else:
groups = dict([(element.label, element)])
- style = self.style[self.cyclic_index]
vdim = dimension_sanitizer(element.vdims[0].name)
# Define CDS data
@@ -1131,7 +1117,7 @@ def get_data(self, element, ranges=None):
# Return if not grouped
if not element.kdims:
- return data, mapping
+ return data, mapping, style
# Define color dimension and data
if cidx is None or cidx>=element.ndims:
@@ -1156,5 +1142,5 @@ def get_data(self, element, ranges=None):
vbar2_map['fill_color'] = {'field': cname, 'transform': mapper}
vbar_map['legend'] = cdim.name
- return data, mapping
+ return data, mapping, style
diff --git a/holoviews/plotting/bokeh/element.py b/holoviews/plotting/bokeh/element.py
index e83f440cd6..57d0a2d4a0 100644
--- a/holoviews/plotting/bokeh/element.py
+++ b/holoviews/plotting/bokeh/element.py
@@ -656,9 +656,8 @@ def _init_glyph(self, plot, mapping, properties):
return renderer, renderer.glyph
- def _glyph_properties(self, plot, element, source, ranges):
- properties = self.style[self.cyclic_index]
-
+ def _glyph_properties(self, plot, element, source, ranges, style):
+ properties = dict(style, source=source)
if self.show_legend:
if self.overlay_dims:
legend = ', '.join([d.pprint_value(v) for d, v in
@@ -666,7 +665,6 @@ def _glyph_properties(self, plot, element, source, ranges):
else:
legend = element.label
properties['legend'] = legend
- properties['source'] = source
return properties
def _update_glyph(self, renderer, properties, mapping, glyph):
@@ -737,16 +735,17 @@ def _init_glyphs(self, plot, element, ranges, source):
# Get data and initialize data source
if self.batched:
current_id = tuple(element.traverse(lambda x: x._plot_id, [Element]))
- data, mapping = self.get_batched_data(element, ranges)
+ data, mapping, style = self.get_batched_data(element, ranges)
else:
- data, mapping = self.get_data(element, ranges)
+ style = self.style[self.cyclic_index]
+ data, mapping, style = self.get_data(element, ranges, style)
current_id = element._plot_id
if source is None:
source = self._init_datasource(data)
self.handles['previous_id'] = current_id
self.handles['source'] = source
- properties = self._glyph_properties(plot, style_element, source, ranges)
+ properties = self._glyph_properties(plot, style_element, source, ranges, style)
with abbreviated_exception():
renderer, glyph = self._init_glyph(plot, mapping, properties)
self.handles['glyph'] = glyph
@@ -819,16 +818,17 @@ def _update_glyphs(self, element, ranges):
current_id = element._plot_id
self.handles['previous_id'] = current_id
self.static_source = (self.dynamic and (current_id == previous_id))
+ style = self.style[self.cyclic_index]
if self.batched:
- data, mapping = self.get_batched_data(element, ranges)
+ data, mapping, style = self.get_batched_data(element, ranges)
else:
- data, mapping = self.get_data(element, ranges)
+ data, mapping, style = self.get_data(element, ranges, style)
if not self.static_source:
self._update_datasource(source, data)
if glyph:
- properties = self._glyph_properties(plot, element, source, ranges)
+ properties = self._glyph_properties(plot, element, source, ranges, style)
renderer = self.handles.get('glyph_renderer')
with abbreviated_exception():
self._update_glyph(renderer, properties, mapping, glyph)
@@ -964,18 +964,15 @@ class CompositeElementPlot(ElementPlot):
def _init_glyphs(self, plot, element, ranges, source):
# Get data and initialize data source
- if self.batched:
- current_id = tuple(element.traverse(lambda x: x._plot_id, [Element]))
- data, mapping = self.get_batched_data(element, ranges)
- else:
- data, mapping = self.get_data(element, ranges)
- current_id = element._plot_id
+ style = self.style[self.cyclic_index]
+ data, mapping, style = self.get_data(element, ranges, style)
+ current_id = element._plot_id
self.handles['previous_id'] = current_id
for key in dict(mapping, **data):
source = self._init_datasource(data.get(key, {}))
self.handles[key+'_source'] = source
- properties = self._glyph_properties(plot, element, source, ranges)
+ properties = self._glyph_properties(plot, element, source, ranges, style)
properties = self._process_properties(key, properties)
with abbreviated_exception():
renderer, glyph = self._init_glyph(plot, mapping.get(key, {}), properties, key)
@@ -1015,7 +1012,8 @@ def _update_glyphs(self, element, ranges):
current_id = element._plot_id
self.handles['previous_id'] = current_id
self.static_source = (self.dynamic and (current_id == previous_id))
- data, mapping = self.get_data(element, ranges)
+ style = self.style[self.cyclic_index]
+ data, mapping, style = self.get_data(element, ranges, style)
for key in dict(mapping, **data):
gdata = data[key]
@@ -1025,7 +1023,7 @@ def _update_glyphs(self, element, ranges):
self._update_datasource(source, gdata)
if glyph:
- properties = self._glyph_properties(plot, element, source, ranges)
+ properties = self._glyph_properties(plot, element, source, ranges, style)
properties = self._process_properties(key, properties)
renderer = self.handles.get(key+'_glyph_renderer')
with abbreviated_exception():
diff --git a/holoviews/plotting/bokeh/graphs.py b/holoviews/plotting/bokeh/graphs.py
index 505213aece..b633d442ea 100644
--- a/holoviews/plotting/bokeh/graphs.py
+++ b/holoviews/plotting/bokeh/graphs.py
@@ -84,8 +84,7 @@ def _get_axis_labels(self, *args, **kwargs):
xlabel, ylabel = [kd.pprint_label for kd in element.nodes.kdims[:2]]
return xlabel, ylabel, None
- def get_data(self, element, ranges=None):
- style = self.style[self.cyclic_index]
+ def get_data(self, element, ranges, style):
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)
# Get node data
@@ -135,7 +134,7 @@ def get_data(self, element, ranges=None):
data = {'scatter_1': point_data, 'multi_line_1': path_data, 'layout': layout}
mapping = {'scatter_1': point_mapping, 'multi_line_1': {}}
- return data, mapping
+ return data, mapping, style
def _update_datasource(self, source, data):
@@ -150,14 +149,15 @@ def _update_datasource(self, source, data):
def _init_glyphs(self, plot, element, ranges, source):
# Get data and initialize data source
- data, mapping = self.get_data(element, ranges)
+ style = self.style[self.cyclic_index]
+ data, mapping, style = self.get_data(element, ranges, style)
self.handles['previous_id'] = element._plot_id
properties = {}
mappings = {}
for key in mapping:
source = self._init_datasource(data.get(key, {}))
self.handles[key+'_source'] = source
- glyph_props = self._glyph_properties(plot, element, source, ranges)
+ glyph_props = self._glyph_properties(plot, element, source, ranges, style)
properties.update(glyph_props)
mappings.update(mapping.get(key, {}))
properties = {p: v for p, v in properties.items() if p not in ('legend', 'source')}
diff --git a/holoviews/plotting/bokeh/path.py b/holoviews/plotting/bokeh/path.py
index 804edb1895..805202b73a 100644
--- a/holoviews/plotting/bokeh/path.py
+++ b/holoviews/plotting/bokeh/path.py
@@ -41,7 +41,7 @@ def _get_hover_data(self, data, element):
data[dim] = [v for _ in range(len(list(data.values())[0]))]
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
if self.static_source:
data = {}
else:
@@ -50,7 +50,7 @@ def get_data(self, element, ranges=None):
xs, ys = ([path[:, idx] for path in paths] for idx in [xidx, yidx])
data = dict(xs=xs, ys=ys)
self._get_hover_data(data, element)
- return data, dict(self._mapping)
+ return data, dict(self._mapping), style
def _categorize_data(self, data, cols, dims):
@@ -75,13 +75,12 @@ def get_batched_data(self, element, ranges=None):
data = defaultdict(list)
zorders = self._updated_zorders(element)
- styles = self.lookup_options(element.last, 'style')
- styles = styles.max_cycles(len(self.ordering))
-
for (key, el), zorder in zip(element.data.items(), zorders):
self.set_param(**self.lookup_options(el, 'plot').options)
+ style = self.lookup_options(el, 'style')
+ style = style.max_cycles(len(self.ordering))[zorder]
self.overlay_dims = dict(zip(element.kdims, key))
- eldata, elmapping = self.get_data(el, ranges)
+ eldata, elmapping, style = self.get_data(el, ranges, style)
for k, eld in eldata.items():
data[k].extend(eld)
@@ -91,24 +90,22 @@ def get_batched_data(self, element, ranges=None):
# Apply static styles
nvals = len(list(eldata.values())[0])
- style = styles[zorder]
sdata, smapping = expand_batched_style(style, self._batched_style_opts,
elmapping, nvals)
elmapping.update({k: v for k, v in smapping.items() if k not in elmapping})
for k, v in sdata.items():
data[k].extend(list(v))
- return data, elmapping
+ return data, elmapping, style
class ContourPlot(ColorbarPlot, PathPlot):
style_opts = line_properties + ['cmap']
- def get_data(self, element, ranges=None):
- data, mapping = super(ContourPlot, self).get_data(element, ranges)
+ def get_data(self, element, ranges, style):
+ data, mapping, style = super(ContourPlot, self).get_data(element, ranges, style)
ncontours = len(list(data.values())[0])
- style = self.style[self.cyclic_index]
if element.vdims and element.level is not None:
cdim = element.vdims[0]
dim_name = util.dimension_sanitizer(cdim.name)
@@ -117,7 +114,7 @@ def get_data(self, element, ranges=None):
if 'cmap' in style:
cmapper = self._get_colormapper(cdim, element, ranges, style)
mapping['line_color'] = {'field': dim_name, 'transform': cmapper}
- return data, mapping
+ return data, mapping, style
class PolygonPlot(ColorbarPlot, PathPlot):
@@ -135,7 +132,7 @@ def _hover_opts(self, element):
dims += element.vdims
return dims, {}
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
if self.static_source:
data = {}
else:
@@ -144,7 +141,6 @@ def get_data(self, element, ranges=None):
ys = [path[:, 1] for path in paths]
data = dict(xs=ys, ys=xs) if self.invert_axes else dict(xs=xs, ys=ys)
- style = self.style[self.cyclic_index]
mapping = dict(self._mapping)
if element.vdims and element.level is not None:
cdim = element.vdims[0]
@@ -161,4 +157,4 @@ def get_data(self, element, ranges=None):
data[dim] = [v for _ in range(len(xs))]
data[dim_name] = [element.level for _ in range(len(xs))]
- return data, mapping
+ return data, mapping, style
diff --git a/holoviews/plotting/bokeh/plot.py b/holoviews/plotting/bokeh/plot.py
index 140906f2fe..4bfe995c0e 100644
--- a/holoviews/plotting/bokeh/plot.py
+++ b/holoviews/plotting/bokeh/plot.py
@@ -94,7 +94,7 @@ def __init__(self, *args, **params):
self.root = None
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
"""
Returns the data from an element in the appropriate format for
initializing or updating a ColumnDataSource and a dictionary
diff --git a/holoviews/plotting/bokeh/raster.py b/holoviews/plotting/bokeh/raster.py
index 4cb822f7b6..deab46c6b8 100644
--- a/holoviews/plotting/bokeh/raster.py
+++ b/holoviews/plotting/bokeh/raster.py
@@ -21,21 +21,13 @@ def __init__(self, *args, **kwargs):
if self.hmap.type == Raster:
self.invert_yaxis = not self.invert_yaxis
-
- def _glyph_properties(self, plot, element, source, ranges):
- properties = super(RasterPlot, self)._glyph_properties(plot, element,
- source, ranges)
- properties = {k: v for k, v in properties.items()}
+ def get_data(self, element, ranges, style):
+ mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh')
val_dim = [d for d in element.vdims][0]
- properties['color_mapper'] = self._get_colormapper(val_dim, element, ranges,
- properties)
- return properties
-
+ style['color_mapper'] = self._get_colormapper(val_dim, element, ranges, style)
- def get_data(self, element, ranges=None):
- mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh')
if self.static_source:
- return {}, mapping
+ return {}, mapping, style
img = element.dimension_values(2, flat=False)
if img.dtype.kind == 'b':
@@ -60,9 +52,9 @@ def get_data(self, element, ranges=None):
img = img[::-1]
b, t = t, b
dh, dw = t-b, r-l
-
+
data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh])
- return (data, mapping)
+ return (data, mapping, style)
@@ -71,10 +63,10 @@ class RGBPlot(RasterPlot):
style_opts = []
_plot_methods = dict(single='image_rgba')
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh')
if self.static_source:
- return {}, mapping
+ return {}, mapping, style
img = np.dstack([element.dimension_values(d, flat=False)
for d in element.vdims])
@@ -104,16 +96,16 @@ def get_data(self, element, ranges=None):
dh, dw = t-b, r-l
data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh])
- return (data, mapping)
+ return (data, mapping, style)
- def _glyph_properties(self, plot, element, source, ranges):
+ def _glyph_properties(self, plot, element, source, ranges, style):
return ElementPlot._glyph_properties(self, plot, element,
- source, ranges)
+ source, ranges, style)
class HSVPlot(RGBPlot):
- def get_data(self, element, ranges=None):
- return super(HSVPlot, self).get_data(element.rgb, ranges)
+ def get_data(self, element, ranges, style):
+ return super(HSVPlot, self).get_data(element.rgb, ranges, style)
class HeatMapPlot(ColorbarPlot):
@@ -138,13 +130,12 @@ class HeatMapPlot(ColorbarPlot):
def _get_factors(self, element):
return super(HeatMapPlot, self)._get_factors(element.gridded)
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
x, y, z = [dimension_sanitizer(d) for d in element.dimensions(label=True)[:3]]
if self.invert_axes: x, y = y, x
- style = self.style[self.cyclic_index]
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if self.static_source:
- return {}, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}}
+ return {}, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}}, style
aggregate = element.gridded
xdim, ydim = aggregate.dimensions()[:2]
@@ -168,7 +159,7 @@ def get_data(self, element, ranges=None):
data[sanitized] = ['-' if is_nan(v) else vdim.pprint_value(v)
for v in aggregate.dimension_values(vdim)]
return (data, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper},
- 'height': 1, 'width': 1})
+ 'height': 1, 'width': 1}, style)
class QuadMeshPlot(ColorbarPlot):
@@ -179,13 +170,12 @@ class QuadMeshPlot(ColorbarPlot):
_plot_methods = dict(single='rect')
style_opts = ['cmap', 'color'] + line_properties + fill_properties
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
x, y, z = element.dimensions(label=True)
if self.invert_axes: x, y = y, x
- style = self.style[self.cyclic_index]
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if self.static_source:
- return {}, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}}
+ return {}, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}}, style
if len(set(v.shape for v in element.data)) == 1:
raise SkipRendering("Bokeh QuadMeshPlot only supports rectangular meshes")
@@ -204,4 +194,4 @@ def get_data(self, element, ranges=None):
data = {x: xs, y: ys, z: zvals, 'widths': ws, 'heights': hs}
return (data, {'x': x, 'y': y,
'fill_color': {'field': z, 'transform': cmapper},
- 'height': 'heights', 'width': 'widths'})
+ 'height': 'heights', 'width': 'widths'}, style)
diff --git a/holoviews/plotting/bokeh/tabular.py b/holoviews/plotting/bokeh/tabular.py
index e20aa275de..4e1e7fe3c9 100644
--- a/holoviews/plotting/bokeh/tabular.py
+++ b/holoviews/plotting/bokeh/tabular.py
@@ -45,13 +45,13 @@ def _execute_hooks(self, element):
self.warning("Plotting hook %r could not be applied:\n\n %s" % (hook, e))
- def get_data(self, element, ranges=None):
+ def get_data(self, element, ranges, style):
dims = element.dimensions()
mapping = {d.name: d.name for d in dims}
data = {d: element.dimension_values(d) for d in dims}
data = {d.name: values if values.dtype.kind in "if" else list(map(d.pprint_value, values))
for d, values in data.items()}
- return data, mapping
+ return data, mapping, style
def initialize_plot(self, ranges=None, plot=None, plots=None, source=None):
@@ -64,18 +64,18 @@ def initialize_plot(self, ranges=None, plot=None, plots=None, source=None):
self.current_frame = element
self.current_key = key
- data, _ = self.get_data(element, ranges)
+ style = self.lookup_options(element, 'style')[self.cyclic_index]
+ data, _, style = self.get_data(element, ranges, style)
if source is None:
source = self._init_datasource(data)
self.handles['source'] = source
dims = element.dimensions()
columns = [TableColumn(field=d.name, title=d.pprint_label) for d in dims]
- properties = self.lookup_options(element, 'style')[self.cyclic_index]
if bokeh_version > '0.12.7':
- properties['reorderable'] = False
+ style['reorderable'] = False
table = DataTable(source=source, columns=columns, height=self.height,
- width=self.width, **properties)
+ width=self.width, **style)
self.handles['plot'] = table
self.handles['glyph_renderer'] = table
self._execute_hooks(element)
@@ -118,5 +118,6 @@ def update_frame(self, key, ranges=None, plot=None):
if self.static_source:
return
source = self.handles['source']
- data, _ = self.get_data(element, ranges)
+ style = self.lookup_options(element, 'style')[self.cyclic_index]
+ data, _, style = self.get_data(element, ranges, style)
self._update_datasource(source, data)