Skip to content

Commit

Permalink
Add support for fast QuadMesh rasterization (#4020)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Oct 4, 2019
1 parent 47b189b commit ac6a590
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 11 deletions.
12 changes: 11 additions & 1 deletion holoviews/core/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,13 +990,17 @@ def to(self):
return self._conversion_interface(self)


def clone(self, data=None, shared_data=True, new_type=None, *args, **overrides):
def clone(self, data=None, shared_data=True, new_type=None, link=True,
*args, **overrides):
"""Clones the object, overriding data and parameters.
Args:
data: New data replacing the existing data
shared_data (bool, optional): Whether to use existing data
new_type (optional): Type to cast object to
link (bool, optional): Whether clone should be linked
Determines whether Streams and Links attached to
original object will be inherited.
*args: Additional arguments to pass to constructor
**overrides: New keyword arguments to pass to constructor
Expand All @@ -1010,6 +1014,12 @@ def clone(self, data=None, shared_data=True, new_type=None, *args, **overrides):
if data is None:
overrides['_validate_vdims'] = False

# Allows datatype conversions
if shared_data:
data = self
if link:
overrides['plot_id'] = self._plot_id

if 'dataset' not in overrides:
overrides['dataset'] = self.dataset

Expand Down
19 changes: 12 additions & 7 deletions holoviews/core/data/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .. import util
from ..element import Element
from ..ndmapping import OrderedDict, NdMapping
from ..ndmapping import NdMapping


def get_array_types():
Expand Down Expand Up @@ -225,14 +225,19 @@ def initialize(cls, eltype, data, kdims, vdims, datatype=None):
if not datatype:
datatype = eltype.datatype

if data.interface.datatype in datatype and data.interface.datatype in eltype.datatype:
interface = data.interface
if interface.datatype in datatype and interface.datatype in eltype.datatype:
data = data.data
elif data.interface.gridded and any(cls.interfaces[dt].gridded for dt in datatype):
gridded = OrderedDict([(kd.name, data.dimension_values(kd.name, expanded=False))
for kd in data.kdims])
elif interface.gridded and any(cls.interfaces[dt].gridded for dt in datatype):
new_data = []
for kd in data.kdims:
irregular = interface.irregular(data, kd)
coords = data.dimension_values(kd.name, expanded=irregular,
flat=not irregular)
new_data.append(coords)
for vd in data.vdims:
gridded[vd.name] = data.dimension_values(vd, flat=False)
data = tuple(gridded.values())
new_data.append(interface.values(data, vd, flat=False, compute=False))
data = tuple(new_data)
else:
data = tuple(data.columns().values())
elif isinstance(data, Element):
Expand Down
43 changes: 42 additions & 1 deletion holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,48 @@ class quadmesh_rasterize(trimesh_rasterize):
"""

def _precompute(self, element, agg):
return super(quadmesh_rasterize, self)._precompute(element.trimesh(), agg)
if ds_version <= '0.7.0':
return super(quadmesh_rasterize, self)._precompute(element.trimesh(), agg)

def _process(self, element, key=None):
if ds_version <= '0.7.0':
return super(quadmesh_rasterize, self)._process(element, key)

if element.interface.datatype != 'xarray':
element = element.clone(datatype=['xarray'])
data = element.data

x, y = element.kdims
agg_fn = self._get_aggregator(element)
info = self._get_sampling(element, x, y)
(x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info
if xtype == 'datetime':
data[x.name] = data[x.name].astype('datetime64[us]').astype('int64')
if ytype == 'datetime':
data[y.name] = data[y.name].astype('datetime64[us]').astype('int64')

# Compute bounds (converting datetimes)
((x0, x1), (y0, y1)), (xs, ys) = self._dt_transform(
x_range, y_range, xs, ys, xtype, ytype
)
params = dict(get_param_values(element), datatype=['xarray'],
bounds=(x0, y0, x1, y1))

if width == 0 or height == 0:
return self._empty_agg(element, x, y, width, height, xs, ys, agg_fn, **params)

cvs = ds.Canvas(plot_width=width, plot_height=height,
x_range=x_range, y_range=y_range)

vdim = getattr(agg_fn, 'column', element.vdims[0].name)
agg = cvs.quadmesh(data[vdim], x.name, y.name, agg_fn)
xdim, ydim = list(agg.dims)[:2][::-1]
if xtype == "datetime":
agg[xdim] = (agg[xdim]/1e3).astype('datetime64[us]')
if ytype == "datetime":
agg[ydim] = (agg[ydim]/1e3).astype('datetime64[us]')

return Image(agg, **params)



Expand Down
4 changes: 2 additions & 2 deletions holoviews/tests/operation/testdatashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,14 +626,14 @@ def test_rasterize_trimesh_string_aggregator(self):
def test_rasterize_quadmesh(self):
qmesh = QuadMesh(([0, 1], [0, 1], np.array([[0, 1], [2, 3]])))
img = rasterize(qmesh, width=3, height=3, dynamic=False, aggregator=ds.mean('z'))
image = Image(np.array([[2., 3., np.NaN], [0, 1, np.NaN], [np.NaN, np.NaN, np.NaN]]),
image = Image(np.array([[2, 3, 3], [2, 3, 3], [0, 1, 1]]),
bounds=(-.5, -.5, 1.5, 1.5))
self.assertEqual(img, image)

def test_rasterize_quadmesh_string_aggregator(self):
qmesh = QuadMesh(([0, 1], [0, 1], np.array([[0, 1], [2, 3]])))
img = rasterize(qmesh, width=3, height=3, dynamic=False, aggregator='mean')
image = Image(np.array([[2., 3., np.NaN], [0, 1, np.NaN], [np.NaN, np.NaN, np.NaN]]),
image = Image(np.array([[2, 3, 3], [2, 3, 3], [0, 1, 1]]),
bounds=(-.5, -.5, 1.5, 1.5))
self.assertEqual(img, image)

Expand Down

0 comments on commit ac6a590

Please sign in to comment.