Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for fast QuadMesh rasterization #4020

Merged
merged 7 commits into from
Oct 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion holoviews/core/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,13 +990,17 @@ def to(self):
return self._conversion_interface(self)


def clone(self, data=None, shared_data=True, new_type=None, *args, **overrides):
def clone(self, data=None, shared_data=True, new_type=None, link=True,
*args, **overrides):
"""Clones the object, overriding data and parameters.

Args:
data: New data replacing the existing data
shared_data (bool, optional): Whether to use existing data
new_type (optional): Type to cast object to
link (bool, optional): Whether clone should be linked
Determines whether Streams and Links attached to
original object will be inherited.
*args: Additional arguments to pass to constructor
**overrides: New keyword arguments to pass to constructor

Expand All @@ -1010,6 +1014,12 @@ def clone(self, data=None, shared_data=True, new_type=None, *args, **overrides):
if data is None:
overrides['_validate_vdims'] = False

# Allows datatype conversions
if shared_data:
data = self
if link:
overrides['plot_id'] = self._plot_id

if 'dataset' not in overrides:
overrides['dataset'] = self.dataset

Expand Down
19 changes: 12 additions & 7 deletions holoviews/core/data/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .. import util
from ..element import Element
from ..ndmapping import OrderedDict, NdMapping
from ..ndmapping import NdMapping


def get_array_types():
Expand Down Expand Up @@ -225,14 +225,19 @@ def initialize(cls, eltype, data, kdims, vdims, datatype=None):
if not datatype:
datatype = eltype.datatype

if data.interface.datatype in datatype and data.interface.datatype in eltype.datatype:
interface = data.interface
if interface.datatype in datatype and interface.datatype in eltype.datatype:
data = data.data
elif data.interface.gridded and any(cls.interfaces[dt].gridded for dt in datatype):
gridded = OrderedDict([(kd.name, data.dimension_values(kd.name, expanded=False))
for kd in data.kdims])
elif interface.gridded and any(cls.interfaces[dt].gridded for dt in datatype):
new_data = []
for kd in data.kdims:
irregular = interface.irregular(data, kd)
coords = data.dimension_values(kd.name, expanded=irregular,
flat=not irregular)
new_data.append(coords)
for vd in data.vdims:
gridded[vd.name] = data.dimension_values(vd, flat=False)
data = tuple(gridded.values())
new_data.append(interface.values(data, vd, flat=False, compute=False))
data = tuple(new_data)
else:
data = tuple(data.columns().values())
elif isinstance(data, Element):
Expand Down
43 changes: 42 additions & 1 deletion holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,48 @@ class quadmesh_rasterize(trimesh_rasterize):
"""

def _precompute(self, element, agg):
return super(quadmesh_rasterize, self)._precompute(element.trimesh(), agg)
if ds_version <= '0.7.0':
return super(quadmesh_rasterize, self)._precompute(element.trimesh(), agg)

def _process(self, element, key=None):
if ds_version <= '0.7.0':
return super(quadmesh_rasterize, self)._process(element, key)

if element.interface.datatype != 'xarray':
element = element.clone(datatype=['xarray'])
data = element.data

x, y = element.kdims
agg_fn = self._get_aggregator(element)
info = self._get_sampling(element, x, y)
(x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info
if xtype == 'datetime':
data[x.name] = data[x.name].astype('datetime64[us]').astype('int64')
if ytype == 'datetime':
data[y.name] = data[y.name].astype('datetime64[us]').astype('int64')

# Compute bounds (converting datetimes)
((x0, x1), (y0, y1)), (xs, ys) = self._dt_transform(
x_range, y_range, xs, ys, xtype, ytype
)
params = dict(get_param_values(element), datatype=['xarray'],
bounds=(x0, y0, x1, y1))

if width == 0 or height == 0:
return self._empty_agg(element, x, y, width, height, xs, ys, agg_fn, **params)

cvs = ds.Canvas(plot_width=width, plot_height=height,
x_range=x_range, y_range=y_range)

vdim = getattr(agg_fn, 'column', element.vdims[0].name)
agg = cvs.quadmesh(data[vdim], x.name, y.name, agg_fn)
xdim, ydim = list(agg.dims)[:2][::-1]
if xtype == "datetime":
agg[xdim] = (agg[xdim]/1e3).astype('datetime64[us]')
if ytype == "datetime":
agg[ydim] = (agg[ydim]/1e3).astype('datetime64[us]')

return Image(agg, **params)



Expand Down
4 changes: 2 additions & 2 deletions holoviews/tests/operation/testdatashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,14 +626,14 @@ def test_rasterize_trimesh_string_aggregator(self):
def test_rasterize_quadmesh(self):
qmesh = QuadMesh(([0, 1], [0, 1], np.array([[0, 1], [2, 3]])))
img = rasterize(qmesh, width=3, height=3, dynamic=False, aggregator=ds.mean('z'))
image = Image(np.array([[2., 3., np.NaN], [0, 1, np.NaN], [np.NaN, np.NaN, np.NaN]]),
image = Image(np.array([[2, 3, 3], [2, 3, 3], [0, 1, 1]]),
bounds=(-.5, -.5, 1.5, 1.5))
self.assertEqual(img, image)

def test_rasterize_quadmesh_string_aggregator(self):
qmesh = QuadMesh(([0, 1], [0, 1], np.array([[0, 1], [2, 3]])))
img = rasterize(qmesh, width=3, height=3, dynamic=False, aggregator='mean')
image = Image(np.array([[2., 3., np.NaN], [0, 1, np.NaN], [np.NaN, np.NaN, np.NaN]]),
image = Image(np.array([[2, 3, 3], [2, 3, 3], [0, 1, 1]]),
bounds=(-.5, -.5, 1.5, 1.5))
self.assertEqual(img, image)

Expand Down