Skip to content

Commit

Permalink
Add internal _sum_zero reduction to compute sum from array of zeros (rather than NaNs)
Browse files Browse the repository at this point in the history

The prior structure was not compatible with the numba.cuda.atomic.add operation.
  • Loading branch information
jonmmease committed Oct 3, 2019
1 parent 7a10a49 commit cf7f5c8
Showing 1 changed file with 48 additions and 8 deletions.
56 changes: 48 additions & 8 deletions datashader/reductions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function

from math import isnan
import numpy as np
from datashape import dshape, isnumeric, Record, Option
from datashape import coretypes as ct
Expand Down Expand Up @@ -177,9 +178,11 @@ def _finalize(bases, **kwargs):
return xr.DataArray(bases[0], **kwargs)


class sum(FloatingReduction):
class _sum_zero(FloatingReduction):
"""Sum of all elements in ``column``.
Elements of resulting aggregate are zero if they are not updated.
Parameters
----------
column : str
Expand All @@ -202,6 +205,40 @@ def _append_float_field(x, y, agg, field):
if not isnan(field):
agg[y, x] += field

@staticmethod
def _combine(aggs):
return aggs.sum(axis=0, dtype='f8')

class sum(FloatingReduction):
"""Sum of all elements in ``column``.
Elements of resulting aggregate are nan if they are not updated.
Parameters
----------
column : str
Name of the column to aggregate over. Column data type must be numeric.
``NaN`` values in the column are skipped.
"""
_dshape = dshape(Option(ct.float64))

@staticmethod
@ngjit
def _append_int_field(x, y, agg, field):
if np.isnan(agg[y, x]):
agg[y, x] = field
else:
agg[y, x] += field

@staticmethod
@ngjit
def _append_float_field(x, y, agg, field):
if not np.isnan(field):
if np.isnan(agg[y, x]):
agg[y, x] = field
else:
agg[y, x] += field

@staticmethod
def _combine(aggs):
missing_vals = np.isnan(aggs)
Expand Down Expand Up @@ -229,7 +266,10 @@ def _create(shape, array_module):

@property
def _temps(self):
return (sum(self.column), count(self.column))
return (_sum_zero(self.column), count(self.column))

def _build_append(self, dshape, schema):
return super(m2, self)._build_append(dshape, schema)

@staticmethod
@ngjit
Expand Down Expand Up @@ -368,13 +408,13 @@ class mean(Reduction):

@property
def _bases(self):
return (sum(self.column), count(self.column))
return (_sum_zero(self.column), count(self.column))

@staticmethod
def _finalize(bases, **kwargs):
sums, counts = bases
with np.errstate(divide='ignore', invalid='ignore'):
x = sums/counts
x = np.where(counts > 0, sums/counts, np.nan)
return xr.DataArray(x, **kwargs)


Expand All @@ -391,13 +431,13 @@ class var(Reduction):

@property
def _bases(self):
return (sum(self.column), count(self.column), m2(self.column))
return (_sum_zero(self.column), count(self.column), m2(self.column))

@staticmethod
def _finalize(bases, **kwargs):
sums, counts, m2s = bases
with np.errstate(divide='ignore', invalid='ignore'):
x = m2s/counts
x = np.where(counts > 0, m2s / counts, np.nan)
return xr.DataArray(x, **kwargs)


Expand All @@ -414,13 +454,13 @@ class std(Reduction):

@property
def _bases(self):
return (sum(self.column), count(self.column), m2(self.column))
return (_sum_zero(self.column), count(self.column), m2(self.column))

@staticmethod
def _finalize(bases, **kwargs):
sums, counts, m2s = bases
with np.errstate(divide='ignore', invalid='ignore'):
x = np.sqrt(m2s/counts)
x = np.where(counts > 0, np.sqrt(m2s / counts), np.nan)
return xr.DataArray(x, **kwargs)


Expand Down

0 comments on commit cf7f5c8

Please sign in to comment.