Skip to content

Commit

Permalink
Implement optimized TriMesh wireframe rendering (#3495)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Feb 15, 2019
1 parent e76fb6b commit fe4c8bb
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 22 deletions.
74 changes: 55 additions & 19 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ class regrid(AggregationOperation):
being overlaid on a much larger background.""")

interpolation = param.ObjectSelector(default='nearest',
objects=['linear', 'nearest'], doc="""
objects=['linear', 'nearest', 'bilinear', None, False], doc="""
Interpolation method""")

upsample = param.Boolean(default=False, doc="""
Expand Down Expand Up @@ -595,7 +595,9 @@ def _process(self, element, key=None):
# Disable upsampling by clipping size and ranges
(xstart, xend), (ystart, yend) = (x_range, y_range)
xspan, yspan = (xend-xstart), (yend-ystart)
if not self.p.upsample and self.p.target is None:
interp = self.p.interpolation or None
if interp == 'bilinear': interp = 'linear'
if not (self.p.upsample or interp is None) and self.p.target is None:
(x0, x1), (y0, y1) = element.range(0), element.range(1)
if isinstance(x0, datetime_types):
x0, x1 = dt_to_int(x0, 'ns'), dt_to_int(x1, 'ns')
Expand Down Expand Up @@ -638,7 +640,7 @@ def _process(self, element, key=None):
arrays = self._get_xarrays(element, coords, xtype, ytype)
agg_fn = self._get_aggregator(element, add_field=False)
for vd, xarr in arrays.items():
rarray = cvs.raster(xarr, upsample_method=self.p.interpolation,
rarray = cvs.raster(xarr, upsample_method=interp,
downsample_method=agg_fn)

# Convert datetime coordinates
Expand Down Expand Up @@ -683,7 +685,7 @@ class trimesh_rasterize(aggregate):
class_=(ds.reductions.Reduction, basestring))

interpolation = param.ObjectSelector(default='bilinear',
objects=['bilinear', None], doc="""
objects=['bilinear', 'linear', None, False], doc="""
The interpolation method to apply during rasterization.""")

def _precompute(self, element, agg):
Expand All @@ -700,6 +702,16 @@ def _precompute(self, element, agg):
return {'mesh': mesh(verts, simplices), 'simplices': simplices,
'vertices': verts}

def _precompute_wireframe(self, element, agg):
if hasattr(element, '_wireframe'):
segments = element._wireframe.data
else:
simplexes = element.array([0, 1, 2, 0]).astype('int')
verts = element.nodes.array([0, 1])
segments = pd.DataFrame(verts[simplexes].reshape(len(simplexes), -1),
columns=['x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3'])
element._wireframe = Dataset(segments, datatype=['dataframe', 'dask'])
return {'segments': segments}

def _process(self, element, key=None):
if isinstance(element, TriMesh):
Expand All @@ -710,17 +722,28 @@ def _process(self, element, key=None):
(x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info

agg = self.p.aggregator
if getattr(agg, 'column', None):
interp = self.p.interpolation or None
precompute = self.p.precompute
if interp == 'linear': interp = 'bilinear'
wireframe = False
if (not (element.vdims or (isinstance(element, TriMesh) and element.nodes.vdims))) and ds_version <= '0.6.9':
self.p.aggregator = ds.any() if isinstance(agg, ds.any) or agg == 'any' else ds.count()
return aggregate._process(self, element, key)
elif ((not interp and (isinstance(agg, (ds.any, ds.count)) or
agg in ['any', 'count']))
or not (element.vdims or element.nodes.vdims)):
wireframe = True
precompute = False # TriMesh itself caches wireframe
agg = self._get_aggregator(element) if isinstance(agg, (ds.any, ds.count)) else ds.any()
vdim = 'Count' if isinstance(agg, ds.count) else 'Any'
elif getattr(agg, 'column', None):
if agg.column in element.vdims:
vdim = element.get_dimension(agg.column)
elif isinstance(element, TriMesh) and agg.column in element.nodes.vdims:
vdim = element.nodes.get_dimension(agg.column)
else:
raise ValueError("Aggregation column %s not found on TriMesh element."
% agg.column)
elif not (element.vdims or (isinstance(element, TriMesh) and element.nodes.vdims)):
self.p.aggregator = ds.count() if not isinstance(agg, ds.any) else agg
return aggregate._process(self, element, key)
else:
if isinstance(element, TriMesh) and element.nodes.vdims:
vdim = element.nodes.vdims[0]
Expand All @@ -730,6 +753,8 @@ def _process(self, element, key=None):

if element._plot_id in self._precomputed:
precomputed = self._precomputed[element._plot_id]
elif wireframe:
precomputed = self._precompute_wireframe(element, agg)
else:
precomputed = self._precompute(element, agg)

Expand All @@ -742,17 +767,25 @@ def _process(self, element, key=None):
bounds = (x_range[0], y_range[0], x_range[1], y_range[1])
return Image((xs, ys, np.zeros((height, width))), bounds=bounds, **params)

simplices = precomputed['simplices']
pts = precomputed['vertices']
mesh = precomputed['mesh']
if self.p.precompute:
if wireframe:
segments = precomputed['segments']
else:
simplices = precomputed['simplices']
pts = precomputed['vertices']
mesh = precomputed['mesh']
if precompute:
self._precomputed = {element._plot_id: precomputed}

cvs = ds.Canvas(plot_width=width, plot_height=height,
x_range=x_range, y_range=y_range)
interpolate = bool(self.p.interpolation)
agg = cvs.trimesh(pts, simplices, agg=agg,
interp=interpolate, mesh=mesh)
if wireframe:
agg = cvs.line(segments, x=['x0', 'x1', 'x2', 'x3'],
y=['y0', 'y1', 'y2', 'y3'], axis=1,
agg=agg)
else:
interpolate = bool(self.p.interpolation)
agg = cvs.trimesh(pts, simplices, agg=agg,
interp=interpolate, mesh=mesh)
return Image(agg, **params)


Expand Down Expand Up @@ -795,9 +828,11 @@ class rasterize(AggregationOperation):
aggregator = param.ClassSelector(class_=(ds.reductions.Reduction, basestring),
default=None)

interpolation = param.ObjectSelector(default='bilinear',
objects=['bilinear', None], doc="""
The interpolation method to apply during rasterization.""")
interpolation = param.ObjectSelector(
default='bilinear', objects=['linear', 'nearest', 'bilinear', None, False], doc="""
The interpolation method to apply during rasterization.
Defaults to linear interpolation and None and False are aliases
of each other.""")

_transforms = [(Image, regrid),
(TriMesh, trimesh_rasterize),
Expand All @@ -814,7 +849,8 @@ class rasterize(AggregationOperation):
def _process(self, element, key=None):
for predicate, transform in self._transforms:
op_params = dict({k: v for k, v in self.p.items()
if k in transform.params() and v is not None},
if k in transform.params()
and not (v is None and k == 'aggregator')},
dynamic=False)
op = transform.instance(**op_params)
op._precomputed = self._precomputed
Expand Down
15 changes: 12 additions & 3 deletions holoviews/tests/operation/testdatashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ def test_rasterize_trimesh_no_vdims(self):
vertices = [(0., 0.), (0., 1.), (1., 0), (1, 1)]
trimesh = TriMesh((simplices, vertices))
img = rasterize(trimesh, width=3, height=3, dynamic=False)
image = Image(np.array([[2, 1, 2], [1, 2, 1], [2, 1, 2]]),
bounds=(0, 0, 1, 1), vdims='Count')
image = Image(np.array([[True, True, True], [True, True, True], [True, True, True]]),
bounds=(0, 0, 1, 1), vdims='Any')
self.assertEqual(img, image)

def test_rasterize_trimesh_no_vdims_zero_range(self):
Expand All @@ -330,7 +330,16 @@ def test_rasterize_trimesh_no_vdims_zero_range(self):
trimesh = TriMesh((simplices, vertices))
img = rasterize(trimesh, height=2, x_range=(0, 0), dynamic=False)
image = Image(([], [0.25, 0.75], np.zeros((2, 0))),
bounds=(0, 0, 0, 1), xdensity=1, vdims='Count')
bounds=(0, 0, 0, 1), xdensity=1, vdims='Any')
self.assertEqual(img, image)

def test_rasterize_trimesh_with_vdims_as_wireframe(self):
simplices = [(0, 1, 2, 0.5), (3, 2, 1, 1.5)]
vertices = [(0., 0.), (0., 1.), (1., 0), (1, 1)]
trimesh = TriMesh((simplices, vertices), vdims=['z'])
img = rasterize(trimesh, width=3, height=3, aggregator='any', interpolation=None, dynamic=False)
image = Image(np.array([[True, True, True], [True, True, True], [True, True, True]]),
bounds=(0, 0, 1, 1), vdims='Any')
self.assertEqual(img, image)

def test_rasterize_trimesh(self):
Expand Down

0 comments on commit fe4c8bb

Please sign in to comment.