From ac6a590054a613a20af4d595ae47563916c663b8 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Fri, 4 Oct 2019 19:10:10 +0200 Subject: [PATCH] Add support for fast QuadMesh rasterization (#4020) --- holoviews/core/data/__init__.py | 12 +++++- holoviews/core/data/interface.py | 19 +++++---- holoviews/operation/datashader.py | 43 ++++++++++++++++++++- holoviews/tests/operation/testdatashader.py | 4 +- 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/holoviews/core/data/__init__.py b/holoviews/core/data/__init__.py index 3dd6d9d4f3..ed021e9887 100644 --- a/holoviews/core/data/__init__.py +++ b/holoviews/core/data/__init__.py @@ -990,13 +990,17 @@ def to(self): return self._conversion_interface(self) - def clone(self, data=None, shared_data=True, new_type=None, *args, **overrides): + def clone(self, data=None, shared_data=True, new_type=None, link=True, + *args, **overrides): """Clones the object, overriding data and parameters. Args: data: New data replacing the existing data shared_data (bool, optional): Whether to use existing data new_type (optional): Type to cast object to + link (bool, optional): Whether clone should be linked + Determines whether Streams and Links attached to + original object will be inherited. *args: Additional arguments to pass to constructor **overrides: New keyword arguments to pass to constructor @@ -1010,6 +1014,12 @@ def clone(self, data=None, shared_data=True, new_type=None, *args, **overrides): if data is None: overrides['_validate_vdims'] = False + # Allows datatype conversions + if shared_data: + data = self + if link: + overrides['plot_id'] = self._plot_id + if 'dataset' not in overrides: overrides['dataset'] = self.dataset diff --git a/holoviews/core/data/interface.py b/holoviews/core/data/interface.py index 477c1b20f5..5f6d955836 100644 --- a/holoviews/core/data/interface.py +++ b/holoviews/core/data/interface.py @@ -7,7 +7,7 @@ from .. import util from ..element import Element -from ..ndmapping import OrderedDict, NdMapping +from ..ndmapping import NdMapping def get_array_types(): @@ -225,14 +225,19 @@ def initialize(cls, eltype, data, kdims, vdims, datatype=None): if not datatype: datatype = eltype.datatype - if data.interface.datatype in datatype and data.interface.datatype in eltype.datatype: + interface = data.interface + if interface.datatype in datatype and interface.datatype in eltype.datatype: data = data.data - elif data.interface.gridded and any(cls.interfaces[dt].gridded for dt in datatype): - gridded = OrderedDict([(kd.name, data.dimension_values(kd.name, expanded=False)) - for kd in data.kdims]) + elif interface.gridded and any(cls.interfaces[dt].gridded for dt in datatype): + new_data = [] + for kd in data.kdims: + irregular = interface.irregular(data, kd) + coords = data.dimension_values(kd.name, expanded=irregular, + flat=not irregular) + new_data.append(coords) for vd in data.vdims: - gridded[vd.name] = data.dimension_values(vd, flat=False) - data = tuple(gridded.values()) + new_data.append(interface.values(data, vd, flat=False, compute=False)) + data = tuple(new_data) else: data = tuple(data.columns().values()) elif isinstance(data, Element): diff --git a/holoviews/operation/datashader.py b/holoviews/operation/datashader.py index 2a5b9b9011..c9cee201a2 100644 --- a/holoviews/operation/datashader.py +++ b/holoviews/operation/datashader.py @@ -1020,7 +1020,48 @@ class quadmesh_rasterize(trimesh_rasterize): """ def _precompute(self, element, agg): - return super(quadmesh_rasterize, self)._precompute(element.trimesh(), agg) + if ds_version <= '0.7.0': + return super(quadmesh_rasterize, self)._precompute(element.trimesh(), agg) + + def _process(self, element, key=None): + if ds_version <= '0.7.0': + return super(quadmesh_rasterize, self)._process(element, key) + + if element.interface.datatype != 'xarray': + element = element.clone(datatype=['xarray']) + data = element.data + + x, y = element.kdims + agg_fn = self._get_aggregator(element) + info = self._get_sampling(element, x, y) + (x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info + if xtype == 'datetime': + data[x.name] = data[x.name].astype('datetime64[us]').astype('int64') + if ytype == 'datetime': + data[y.name] = data[y.name].astype('datetime64[us]').astype('int64') + + # Compute bounds (converting datetimes) + ((x0, x1), (y0, y1)), (xs, ys) = self._dt_transform( + x_range, y_range, xs, ys, xtype, ytype + ) + params = dict(get_param_values(element), datatype=['xarray'], + bounds=(x0, y0, x1, y1)) + + if width == 0 or height == 0: + return self._empty_agg(element, x, y, width, height, xs, ys, agg_fn, **params) + + cvs = ds.Canvas(plot_width=width, plot_height=height, + x_range=x_range, y_range=y_range) + + vdim = getattr(agg_fn, 'column', element.vdims[0].name) + agg = cvs.quadmesh(data[vdim], x.name, y.name, agg_fn) + xdim, ydim = list(agg.dims)[:2][::-1] + if xtype == "datetime": + agg[xdim] = (agg[xdim]/1e3).astype('datetime64[us]') + if ytype == "datetime": + agg[ydim] = (agg[ydim]/1e3).astype('datetime64[us]') + + return Image(agg, **params) diff --git a/holoviews/tests/operation/testdatashader.py b/holoviews/tests/operation/testdatashader.py index 8d016e64b2..99bfd38dc2 100644 --- a/holoviews/tests/operation/testdatashader.py +++ b/holoviews/tests/operation/testdatashader.py @@ -626,14 +626,14 @@ def test_rasterize_trimesh_string_aggregator(self): def test_rasterize_quadmesh(self): qmesh = QuadMesh(([0, 1], [0, 1], np.array([[0, 1], [2, 3]]))) img = rasterize(qmesh, width=3, height=3, dynamic=False, aggregator=ds.mean('z')) - image = Image(np.array([[2., 3., np.NaN], [0, 1, np.NaN], [np.NaN, np.NaN, np.NaN]]), + image = Image(np.array([[2, 3, 3], [2, 3, 3], [0, 1, 1]]), bounds=(-.5, -.5, 1.5, 1.5)) self.assertEqual(img, image) def test_rasterize_quadmesh_string_aggregator(self): qmesh = QuadMesh(([0, 1], [0, 1], np.array([[0, 1], [2, 3]]))) img = rasterize(qmesh, width=3, height=3, dynamic=False, aggregator='mean') - image = Image(np.array([[2., 3., np.NaN], [0, 1, np.NaN], [np.NaN, np.NaN, np.NaN]]), + image = Image(np.array([[2, 3, 3], [2, 3, 3], [0, 1, 1]]), bounds=(-.5, -.5, 1.5, 1.5)) self.assertEqual(img, image)