Skip to content

Commit

Permalink
Cupy implementation of eq_hist
Browse files Browse the repository at this point in the history
  • Loading branch information
ianthomas23 committed Oct 5, 2022
1 parent a1de2d5 commit 2df3e34
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
5 changes: 5 additions & 0 deletions datashader/tests/test_transfer_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def test_shade(agg, attr, span):

img = tf.shade(x, cmap=cmap, how='eq_hist', rescale_discrete_levels=True)
sol = tf.Image(eq_hist_sol_rescale_discrete_levels[attr], coords=coords, dims=dims)
if cupy and attr=='a' and isinstance(agg.a.data, cupy.ndarray):
# cupy eq_hist has slightly different numerics hence slightly different RGBA results
sol = sol.copy(deep=True)
sol[2, 0] = sol[2, 0] - 0x100

assert_eq_xr(img, sol)

img = tf.shade(x, cmap=cmap,
Expand Down
19 changes: 11 additions & 8 deletions datashader/transfer_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,36 +160,39 @@ def eq_hist(data, mask=None, nbins=256*256):
"""
if cupy and isinstance(data, cupy.ndarray):
from ._cuda_utils import interp
array_module = cupy
elif not isinstance(data, np.ndarray):
raise TypeError("data must be an ndarray")
else:
interp = np.interp
array_module = np

data2 = data if mask is None else data[~mask]

# Run more accurate value counting if data is of boolean or integer type
# and unique value array is smaller than nbins.
if data2.dtype == bool or (np.issubdtype(data2.dtype, np.integer) and data2.ptp() < nbins):
values, counts = np.unique(data2, return_counts=True)
vmin, vmax = values[0], values[-1]
if data2.dtype == bool or (array_module.issubdtype(data2.dtype, array_module.integer) and
data2.ptp() < nbins):
values, counts = array_module.unique(data2, return_counts=True)
vmin, vmax = values[0].item(), values[-1].item() # Convert from arrays to scalars.
interval = vmax-vmin
bin_centers = np.arange(vmin, vmax+1)
hist = np.zeros(interval+1, dtype='uint64')
bin_centers = array_module.arange(vmin, vmax+1)
hist = array_module.zeros(interval+1, dtype='uint64')
hist[values-vmin] = counts
discrete_levels = len(values)
else:
hist, bin_edges = np.histogram(data2, bins=nbins)
hist, bin_edges = array_module.histogram(data2, bins=nbins)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
keep_mask = (hist > 0)
discrete_levels = np.count_nonzero(keep_mask)
discrete_levels = array_module.count_nonzero(keep_mask)
if discrete_levels != len(hist):
# Remove empty histogram bins.
hist = hist[keep_mask]
bin_centers = bin_centers[keep_mask]
cdf = hist.cumsum()
cdf = cdf / float(cdf[-1])
out = interp(data, bin_centers, cdf).reshape(data.shape)
return out if mask is None else np.where(mask, np.nan, out), discrete_levels
return out if mask is None else array_module.where(mask, array_module.nan, out), discrete_levels



Expand Down

0 comments on commit 2df3e34

Please sign in to comment.