Skip to content

Commit

Permalink
Merge pull request #963 from ioam/datashader_extensible
Browse files Browse the repository at this point in the history
Small fixes and optimizations for datashader operations
  • Loading branch information
jbednar authored Nov 1, 2016
2 parents beb27db + 9dac3fa commit e879523
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def discover(dataset):
Allows datashader to correctly discover the dtypes of the data
in a holoviews Element.
"""
if isinstance(dataset.interface, (PandasInterface, ArrayInterface)):
if dataset.interface in [PandasInterface, ArrayInterface]:
return dsdiscover(dataset.data)
else:
return dsdiscover(dataset.dframe())
Expand All @@ -54,21 +54,11 @@ def dataset_pipeline(dataset, schema, canvas, glyph, summary):
vdims = [dataset.get_dimension(column)(name) if column
else Dimension('Count')]

agg = pandas_pipeline(dataset.dframe(), schema, canvas,
glyph, summary)
agg = pandas_pipeline(dataset.data, schema, canvas,
glyph, summary)
agg = agg.rename({'x_axis': kdims[0].name,
'y_axis': kdims[1].name})

params = dict(get_param_values(dataset), kdims=kdims,
datatype=['xarray'], vdims=vdims)

if agg.ndim == 2:
return GridImage(agg, **params)
else:
return NdOverlay({c: GridImage(agg.sel(**{column: c}),
**params)
for c in agg.coords[column].data},
kdims=[dataset.get_dimension(column)])
return agg


class aggregate(ElementOperation):
Expand Down Expand Up @@ -120,6 +110,11 @@ class aggregate(ElementOperation):
List of streams that are applied if dynamic=True, allowing
for dynamic interaction with the plot.""")

element_type = param.ClassSelector(class_=(Dataset,), instantiate=False,
is_instance=False, default=GridImage,
doc="""
The type of the returned Elements, must be a 2D Dataset type.""")

@classmethod
def get_agg_data(cls, obj, category=None):
"""
Expand All @@ -130,6 +125,7 @@ def get_agg_data(cls, obj, category=None):
kdims = obj.kdims
vdims = obj.vdims
x, y = obj.dimensions(label=True)[:2]
is_df = lambda x: isinstance(x, Dataset) and x.interface is PandasInterface
if isinstance(obj, Path):
glyph = 'line'
for p in obj.data:
Expand All @@ -140,20 +136,23 @@ def get_agg_data(cls, obj, category=None):
elif isinstance(obj, CompositeOverlay):
for key, el in obj.data.items():
x, y, element, glyph = cls.get_agg_data(el)
df = element.dframe()
df = element.data if is_df(element) else element.dframe()
if isinstance(obj, NdOverlay):
df = df.assign(**dict(zip(obj.dimensions('key', True), key)))
paths.append(df)
kdims += element.kdims
vdims = element.vdims
elif isinstance(obj, Element):
glyph = 'line' if isinstance(obj, Curve) else 'points'
paths.append(obj.dframe())
paths.append(obj.data if is_df(obj) else obj.dframe())
if glyph == 'line':
empty = paths[0][:1].copy()
empty.loc[0, :] = (np.NaN,) * empty.shape[1]
paths = [elem for path in paths for elem in (path, empty)][:-1]
df = pd.concat(paths).reset_index(drop=True)
if len(paths) > 1:
df = pd.concat(paths).reset_index(drop=True)
else:
df = paths[0]
if category and df[category].dtype.name != 'category':
df[category] = df[category].astype('category')
return x, y, Dataset(df, kdims=kdims, vdims=vdims), glyph
Expand All @@ -178,7 +177,26 @@ def _process(self, element, key=None):

cvs = ds.Canvas(plot_width=width, plot_height=height,
x_range=(xstart, xend), y_range=(ystart, yend))
return getattr(cvs, glyph)(data, x, y, self.p.aggregator)

column = agg_fn.column
if column and isinstance(agg_fn, ds.count_cat):
name = '%s Count' % agg_fn.column
else:
name = column
vdims = [element.get_dimension(column)(name) if column
else Dimension('Count')]
params = dict(get_param_values(element), kdims=[element.dimensions()[0:2]],
datatype=['xarray'], vdims=vdims)

agg = getattr(cvs, glyph)(data, x, y, self.p.aggregator)
if agg.ndim == 2:
return self.p.element_type(agg, **params)
else:
return NdOverlay({c: self.p.element_type(agg.sel(**{column: c}),
**params)
for c in agg.coords[column].data},
kdims=[data.get_dimension(column)])




Expand All @@ -195,7 +213,7 @@ class shade(ElementOperation):
Iterable or a Callable.
"""

cmap = param.ClassSelector(class_=(Iterable, Callable), doc="""
cmap = param.ClassSelector(class_=(Iterable, Callable, dict), doc="""
Iterable or callable which returns colors as hex colors.
Callable type must allow mapping colors between 0 and 1.""")

Expand Down Expand Up @@ -259,6 +277,8 @@ def _process(self, element, key=None):
categories = array.shape[-1]
if not self.p.cmap:
pass
elif isinstance(self.p.cmap, dict):
shade_opts['color_key'] = self.p.cmap
elif isinstance(self.p.cmap, Iterable):
shade_opts['color_key'] = [c for i, c in
zip(range(categories), self.p.cmap)]
Expand Down

0 comments on commit e879523

Please sign in to comment.