Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support by reductions in datashader operations #4438

Merged
merged 1 commit into from
May 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _get_aggregator(self, element, add_field=True):
agg = self._agg_methods[agg]()

elements = element.traverse(lambda x: x, [Element])
if add_field and agg.column is None and not isinstance(agg, (rd.count, rd.any)):
if add_field and getattr(agg, 'column', False) is None and not isinstance(agg, (rd.count, rd.any)):
if not elements:
raise ValueError('Could not find any elements to apply '
'%s operation to.' % type(self).__name__)
Expand Down Expand Up @@ -291,6 +291,10 @@ def _get_agg_params(self, element, x, y, agg_fn, bounds):
params = dict(get_param_values(element), kdims=[x, y],
datatype=['xarray'], bounds=bounds)

category = None
if hasattr(agg_fn, 'reduction'):
category = agg_fn.cat_column
agg_fn = agg_fn.reduction
column = agg_fn.column if agg_fn else None
if column:
dims = [d for d in element.dimensions('ranges') if d == column]
Expand All @@ -300,6 +304,8 @@ def _get_agg_params(self, element, x, y, agg_fn, bounds):
"dimension." % (column,element))
name = '%s Count' % column if isinstance(agg_fn, ds.count_cat) else column
vdims = [dims[0].clone(name)]
elif category:
vdims = Dimension('%s Count' % category)
else:
vdims = Dimension('Count')
params['vdims'] = vdims
Expand Down Expand Up @@ -332,7 +338,6 @@ class aggregate(AggregationOperation):
the linked plot.
"""


@classmethod
def get_agg_data(cls, obj, category=None):
"""
Expand Down Expand Up @@ -409,7 +414,10 @@ def get_agg_data(cls, obj, category=None):

def _process(self, element, key=None):
agg_fn = self._get_aggregator(element)
category = agg_fn.column if isinstance(agg_fn, ds.count_cat) else None
if hasattr(agg_fn, 'cat_column'):
category = agg_fn.cat_column
else:
category = agg_fn.column if isinstance(agg_fn, ds.count_cat) else None

if overlay_aggregate.applies(element, agg_fn):
params = dict(
Expand Down
2 changes: 1 addition & 1 deletion holoviews/operation/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ class histogram(Operation):
bin_range = param.NumericTuple(default=None, length=2, doc="""
Specifies the range within which to compute the bins.""")

bins = param.ClassSelector(default=None, class_=(np.ndarray, list, tuple), doc="""
bins = param.ClassSelector(default=None, class_=(np.ndarray, list, tuple, str), doc="""
An explicit set of bin edges.""")

cumulative = param.Boolean(default=False, doc="""
Expand Down
46 changes: 40 additions & 6 deletions holoviews/tests/operation/testdatashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from unittest import SkipTest, skipIf

import numpy as np

from holoviews import (Dimension, Curve, Points, Image, Dataset, RGB, Path,
Graph, TriMesh, QuadMesh, NdOverlay, Contours, Spikes,
Spread, Area, Segments, Polygons)
from holoviews.element.comparison import ComparisonTestCase
from numpy import nan

try:
import datashader as ds
Expand Down Expand Up @@ -35,6 +37,8 @@
cudf_skip = skipIf(cudf is None, "cuDF not available")




class DatashaderAggregateTests(ComparisonTestCase):
"""
Tests for datashader aggregation
Expand Down Expand Up @@ -620,6 +624,36 @@ def test_multi_poly_rasterize(self):



class DatashaderCatAggregateTests(ComparisonTestCase):

def setUp(self):
if ds_version < '0.11.0':
raise SkipTest('Regridding operations require datashader>=0.11.0')

def test_aggregate_points_categorical(self):
points = Points([(0.2, 0.3, 'A'), (0.4, 0.7, 'B'), (0, 0.99, 'C')], vdims='z')
img = aggregate(points, dynamic=False, x_range=(0, 1), y_range=(0, 1),
width=2, height=2, aggregator=ds.by('z', ds.count()))
xs, ys = [0.25, 0.75], [0.25, 0.75]
expected = NdOverlay({'A': Image((xs, ys, [[1, 0], [0, 0]]), vdims='z Count'),
'B': Image((xs, ys, [[0, 0], [1, 0]]), vdims='z Count'),
'C': Image((xs, ys, [[0, 0], [1, 0]]), vdims='z Count')},
kdims=['z'])
self.assertEqual(img, expected)


def test_aggregate_points_categorical_mean(self):
points = Points([(0.2, 0.3, 'A', 0.1), (0.4, 0.7, 'B', 0.2), (0, 0.99, 'C', 0.3)], vdims=['cat', 'z'])
img = aggregate(points, dynamic=False, x_range=(0, 1), y_range=(0, 1),
width=2, height=2, aggregator=ds.by('cat', ds.mean('z')))
xs, ys = [0.25, 0.75], [0.25, 0.75]
expected = NdOverlay({'A': Image((xs, ys, [[0.1, nan], [nan, nan]]), vdims='z'),
'B': Image((xs, ys, [[nan, nan], [0.2, nan]]), vdims='z'),
'C': Image((xs, ys, [[nan, nan], [0.3, nan]]), vdims='z')},
kdims=['cat'])
self.assertEqual(img, expected)



class DatashaderShadeTests(ComparisonTestCase):

Expand All @@ -633,9 +667,9 @@ def test_shade_categorical_images_xarray(self):
datatype=['xarray'], vdims='z Count')},
kdims=['z'])
shaded = shade(data)
r = [[228, 255], [66, 255]]
g = [[26, 255], [150, 255]]
b = [[28, 255], [129, 255]]
r = [[228, 120], [66, 120]]
g = [[26, 109], [150, 109]]
b = [[28, 95], [129, 95]]
a = [[40, 0], [255, 0]]
expected = RGB((xs, ys, r, g, b, a), datatype=['grid'],
vdims=RGB.vdims+[Dimension('A', range=(0, 1))])
Expand All @@ -651,9 +685,9 @@ def test_shade_categorical_images_grid(self):
datatype=['grid'], vdims='z Count')},
kdims=['z'])
shaded = shade(data)
r = [[228, 255], [66, 255]]
g = [[26, 255], [150, 255]]
b = [[28, 255], [129, 255]]
r = [[228, 120], [66, 120]]
g = [[26, 109], [150, 109]]
b = [[28, 95], [129, 95]]
a = [[40, 0], [255, 0]]
expected = RGB((xs, ys, r, g, b, a), datatype=['grid'],
vdims=RGB.vdims+[Dimension('A', range=(0, 1))])
Expand Down