Skip to content

Commit

Permalink
Allow the use of xarray objects as the input to datashader 2D functio…
Browse files Browse the repository at this point in the history
…ns (#675)
  • Loading branch information
jsignell authored and philippjfr committed Dec 21, 2018
1 parent 14aa424 commit 8dccea2
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 18 deletions.
50 changes: 32 additions & 18 deletions datashader/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def points(self, source, x, y, agg=None):
Parameters
----------
source : pandas.DataFrame, dask.DataFrame
source : pandas.DataFrame, dask.DataFrame, or xarray.DataArray/Dataset
The input datasource.
x, y : str
Column names for the x and y coordinates of each point.
Expand All @@ -174,7 +174,7 @@ def line(self, source, x, y, agg=None):
Parameters
----------
source : pandas.DataFrame, dask.DataFrame
source : pandas.DataFrame, dask.DataFrame, or xarray.DataArray/Dataset
The input datasource.
x, y : str
Column names for the x and y coordinates of each vertex.
Expand Down Expand Up @@ -234,7 +234,7 @@ def trimesh(self, vertices, simplices, mesh=None, agg=None, interp=True, interpo
Method to use for interpolation between specified values. ``nearest``
means to use a single value for the whole triangle, and ``linear``
means to do bilinear interpolation of the pixels within each
triangle (a weighted average of the vertex values). For
triangle (a weighted average of the vertex values). For
backwards compatibility, also accepts ``interp=True`` for ``linear``
and ``interp=False`` for ``nearest``.
"""
Expand All @@ -252,7 +252,7 @@ def trimesh(self, vertices, simplices, mesh=None, agg=None, interp=True, interpo
interp = False
else:
raise ValueError('Invalid interpolate method: options include {}'.format(['linear','nearest']))

# Validation is done inside the [pd]d_mesh utility functions
if source is None:
source = create_mesh(vertices, simplices)
Expand Down Expand Up @@ -319,7 +319,7 @@ def raster(self,
# For backwards compatibility
if agg is None: agg=downsample_method
if interpolate is None: interpolate=upsample_method

upsample_methods = ['nearest','linear']

downsample_methods = {'first':'first', rd.first:'first',
Expand Down Expand Up @@ -387,7 +387,7 @@ def raster(self,

if self.x_range is None: self.x_range = (left,right)
if self.y_range is None: self.y_range = (bottom,top)

# window coordinates
xmin = max(self.x_range[0], left)
ymin = max(self.y_range[0], bottom)
Expand Down Expand Up @@ -505,23 +505,22 @@ def bypixel(source, canvas, glyph, agg):
glyph : Glyph
agg : Reduction
"""
if isinstance(source, DataArray):
if not source.name:
source.name = 'value'
source = source.reset_coords()
if isinstance(source, Dataset):
columns = list(source.coords.keys()) + list(source.data_vars.keys())
cols_to_keep = _cols_to_keep(columns, glyph, agg)
source = source.drop([col for col in columns if col not in cols_to_keep])
source = source.to_dask_dataframe()

if isinstance(source, pd.DataFrame):
# Avoid datashape.Categorical instantiation bottleneck
# by only retaining the necessary columns:
# https://github.com/bokeh/datashader/issues/396
# Preserve column ordering without duplicates
cols_to_keep = OrderedDict({col: False for col in source.columns})
cols_to_keep[glyph.x] = True
cols_to_keep[glyph.y] = True
if hasattr(glyph, 'z'):
cols_to_keep[glyph.z[0]] = True
if hasattr(agg, 'values'):
for subagg in agg.values:
if subagg.column is not None:
cols_to_keep[subagg.column] = True
elif agg.column is not None:
cols_to_keep[agg.column] = True
cols_to_keep = [col for col, keepit in cols_to_keep.items() if keepit]
cols_to_keep = _cols_to_keep(source.columns, glyph, agg)
if len(cols_to_keep) < len(source.columns):
source = source[cols_to_keep]
dshape = dshape_from_pandas(source)
Expand All @@ -540,4 +539,19 @@ def bypixel(source, canvas, glyph, agg):
return bypixel.pipeline(source, schema, canvas, glyph, agg)


def _cols_to_keep(columns, glyph, agg):
cols_to_keep = OrderedDict({col: False for col in columns})
cols_to_keep[glyph.x] = True
cols_to_keep[glyph.y] = True
if hasattr(glyph, 'z'):
cols_to_keep[glyph.z[0]] = True
if hasattr(agg, 'values'):
for subagg in agg.values:
if subagg.column is not None:
cols_to_keep[subagg.column] = True
elif agg.column is not None:
cols_to_keep[agg.column] = True
return [col for col, keepit in cols_to_keep.items() if keepit]


bypixel.pipeline = Dispatcher()
53 changes: 53 additions & 0 deletions datashader/tests/test_xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import xarray as xr

import datashader as ds

import pytest


xda = xr.DataArray(data=np.array(([1.] * 10 + [10] * 10)),
dims=('record'),
coords={'x': xr.DataArray(np.array(([0.] * 10 + [1] * 10)), dims=('record')),
'y': xr.DataArray(np.array(([0.] * 5 + [1] * 5 + [0] * 5 + [1] * 5)), dims=('record')),
'i32': xr.DataArray(np.arange(20, dtype='i4'), dims=('record')),
'i64': xr.DataArray(np.arange(20, dtype='i8'), dims=('record')),
'f32': xr.DataArray(np.arange(20, dtype='f4'), dims=('record')),
'f64': xr.DataArray(np.arange(20, dtype='f8'), dims=('record')),
})
xda.f32[2] = np.nan
xda.f64[2] = np.nan
xds = xda.to_dataset(name='value').reset_coords(names=['i32', 'i64'])

xdda = xda.chunk(chunks=5)
xdds = xds.chunk(chunks=5)

c = ds.Canvas(plot_width=2, plot_height=2, x_range=(0, 1), y_range=(0, 1))

axis = ds.core.LinearAxis()
lincoords = axis.compute_index(axis.compute_scale_and_translate((0, 1), 2), 2)
coords = [lincoords, lincoords]
dims = ['y', 'x']


def assert_eq(agg, b):
assert agg.equals(b)


@pytest.mark.parametrize("source", [
(xda), (xdda), (xds), (xdds),
])
def test_count(source):
out = xr.DataArray(np.array([[5, 5], [5, 5]], dtype='i4'),
coords=coords, dims=dims)
assert_eq(c.points(source, 'x', 'y', ds.count('i32')), out)
assert_eq(c.points(source, 'x', 'y', ds.count('i64')), out)
assert_eq(c.points(source, 'x', 'y', ds.count()), out)
assert_eq(c.points(source, 'x', 'y', ds.count('value')), out)
out = xr.DataArray(np.array([[4, 5], [5, 5]], dtype='i4'),
coords=coords, dims=dims)
assert_eq(c.points(source, 'x', 'y', ds.count('f32')), out)
assert_eq(c.points(source, 'x', 'y', ds.count('f64')), out)



0 comments on commit 8dccea2

Please sign in to comment.