From 8dccea21c9a8306528e21465dd3aa03b373bffc1 Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Fri, 21 Dec 2018 09:33:24 -0500 Subject: [PATCH] Allow the use of xarray objects as the input to datashader 2D functions (#675) --- datashader/core.py | 50 ++++++++++++++++++++----------- datashader/tests/test_xarray.py | 53 +++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 18 deletions(-) create mode 100644 datashader/tests/test_xarray.py diff --git a/datashader/core.py b/datashader/core.py index 708012d3f..40c40beab 100644 --- a/datashader/core.py +++ b/datashader/core.py @@ -147,7 +147,7 @@ def points(self, source, x, y, agg=None): Parameters ---------- - source : pandas.DataFrame, dask.DataFrame + source : pandas.DataFrame, dask.DataFrame, or xarray.DataArray/Dataset The input datasource. x, y : str Column names for the x and y coordinates of each point. @@ -174,7 +174,7 @@ def line(self, source, x, y, agg=None): Parameters ---------- - source : pandas.DataFrame, dask.DataFrame + source : pandas.DataFrame, dask.DataFrame, or xarray.DataArray/Dataset The input datasource. x, y : str Column names for the x and y coordinates of each vertex. @@ -234,7 +234,7 @@ def trimesh(self, vertices, simplices, mesh=None, agg=None, interp=True, interpo Method to use for interpolation between specified values. ``nearest`` means to use a single value for the whole triangle, and ``linear`` means to do bilinear interpolation of the pixels within each - triangle (a weighted average of the vertex values). For + triangle (a weighted average of the vertex values). For backwards compatibility, also accepts ``interp=True`` for ``linear`` and ``interp=False`` for ``nearest``. """ @@ -252,7 +252,7 @@ def trimesh(self, vertices, simplices, mesh=None, agg=None, interp=True, interpo interp = False else: raise ValueError('Invalid interpolate method: options include {}'.format(['linear','nearest'])) - + # Validation is done inside the [pd]d_mesh utility functions if source is None: source = create_mesh(vertices, simplices) @@ -319,7 +319,7 @@ def raster(self, # For backwards compatibility if agg is None: agg=downsample_method if interpolate is None: interpolate=upsample_method - + upsample_methods = ['nearest','linear'] downsample_methods = {'first':'first', rd.first:'first', @@ -387,7 +387,7 @@ def raster(self, if self.x_range is None: self.x_range = (left,right) if self.y_range is None: self.y_range = (bottom,top) - + # window coordinates xmin = max(self.x_range[0], left) ymin = max(self.y_range[0], bottom) @@ -505,23 +505,22 @@ def bypixel(source, canvas, glyph, agg): glyph : Glyph agg : Reduction """ + if isinstance(source, DataArray): + if not source.name: + source.name = 'value' + source = source.reset_coords() + if isinstance(source, Dataset): + columns = list(source.coords.keys()) + list(source.data_vars.keys()) + cols_to_keep = _cols_to_keep(columns, glyph, agg) + source = source.drop([col for col in columns if col not in cols_to_keep]) + source = source.to_dask_dataframe() + if isinstance(source, pd.DataFrame): # Avoid datashape.Categorical instantiation bottleneck # by only retaining the necessary columns: # https://github.com/bokeh/datashader/issues/396 # Preserve column ordering without duplicates - cols_to_keep = OrderedDict({col: False for col in source.columns}) - cols_to_keep[glyph.x] = True - cols_to_keep[glyph.y] = True - if hasattr(glyph, 'z'): - cols_to_keep[glyph.z[0]] = True - if hasattr(agg, 'values'): - for subagg in agg.values: - if subagg.column is not None: - cols_to_keep[subagg.column] = True - elif agg.column is not None: - cols_to_keep[agg.column] = True - cols_to_keep = [col for col, keepit in cols_to_keep.items() if keepit] + cols_to_keep = _cols_to_keep(source.columns, glyph, agg) if len(cols_to_keep) < len(source.columns): source = source[cols_to_keep] dshape = dshape_from_pandas(source) @@ -540,4 +539,19 @@ def bypixel(source, canvas, glyph, agg): return bypixel.pipeline(source, schema, canvas, glyph, agg) +def _cols_to_keep(columns, glyph, agg): + cols_to_keep = OrderedDict({col: False for col in columns}) + cols_to_keep[glyph.x] = True + cols_to_keep[glyph.y] = True + if hasattr(glyph, 'z'): + cols_to_keep[glyph.z[0]] = True + if hasattr(agg, 'values'): + for subagg in agg.values: + if subagg.column is not None: + cols_to_keep[subagg.column] = True + elif agg.column is not None: + cols_to_keep[agg.column] = True + return [col for col, keepit in cols_to_keep.items() if keepit] + + bypixel.pipeline = Dispatcher() diff --git a/datashader/tests/test_xarray.py b/datashader/tests/test_xarray.py new file mode 100644 index 000000000..ce24d22f9 --- /dev/null +++ b/datashader/tests/test_xarray.py @@ -0,0 +1,53 @@ +import numpy as np +import xarray as xr + +import datashader as ds + +import pytest + + +xda = xr.DataArray(data=np.array(([1.] * 10 + [10] * 10)), + dims=('record'), + coords={'x': xr.DataArray(np.array(([0.] * 10 + [1] * 10)), dims=('record')), + 'y': xr.DataArray(np.array(([0.] * 5 + [1] * 5 + [0] * 5 + [1] * 5)), dims=('record')), + 'i32': xr.DataArray(np.arange(20, dtype='i4'), dims=('record')), + 'i64': xr.DataArray(np.arange(20, dtype='i8'), dims=('record')), + 'f32': xr.DataArray(np.arange(20, dtype='f4'), dims=('record')), + 'f64': xr.DataArray(np.arange(20, dtype='f8'), dims=('record')), + }) +xda.f32[2] = np.nan +xda.f64[2] = np.nan +xds = xda.to_dataset(name='value').reset_coords(names=['i32', 'i64']) + +xdda = xda.chunk(chunks=5) +xdds = xds.chunk(chunks=5) + +c = ds.Canvas(plot_width=2, plot_height=2, x_range=(0, 1), y_range=(0, 1)) + +axis = ds.core.LinearAxis() +lincoords = axis.compute_index(axis.compute_scale_and_translate((0, 1), 2), 2) +coords = [lincoords, lincoords] +dims = ['y', 'x'] + + +def assert_eq(agg, b): + assert agg.equals(b) + + +@pytest.mark.parametrize("source", [ + (xda), (xdda), (xds), (xdds), +]) +def test_count(source): + out = xr.DataArray(np.array([[5, 5], [5, 5]], dtype='i4'), + coords=coords, dims=dims) + assert_eq(c.points(source, 'x', 'y', ds.count('i32')), out) + assert_eq(c.points(source, 'x', 'y', ds.count('i64')), out) + assert_eq(c.points(source, 'x', 'y', ds.count()), out) + assert_eq(c.points(source, 'x', 'y', ds.count('value')), out) + out = xr.DataArray(np.array([[4, 5], [5, 5]], dtype='i4'), + coords=coords, dims=dims) + assert_eq(c.points(source, 'x', 'y', ds.count('f32')), out) + assert_eq(c.points(source, 'x', 'y', ds.count('f64')), out) + + +