Skip to content

Commit

Permalink
Merge pull request #121 from jcrist/eq_hist_fix
Browse files Browse the repository at this point in the history
Fix for integer eq_hist
  • Loading branch information
jcrist committed Mar 23, 2016
2 parents 2b10738 + 497ae83 commit b26a5d4
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 40 deletions.
41 changes: 25 additions & 16 deletions datashader/tests/test_transfer_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
agg = xr.Dataset(dict(a=s_a, b=s_b, c=s_c))


eq_hist_sol = {'a': np.array([[0, 4291543295, 4288846335],
[4286149631, 0, 4283518207],
[4280821503, 4278190335, 0]], dtype='u4'),
'b': np.array([[0, 4291543295, 4288846335],
[4286609919, 0, 4283518207],
[4281281791, 4278190335, 0]], dtype='u4')}
eq_hist_sol['c'] = eq_hist_sol['b']


@pytest.mark.parametrize(['attr'], ['a', 'b', 'c'])
def test_interpolate(attr):
x = getattr(agg, attr)
Expand All @@ -48,12 +57,10 @@ def test_interpolate(attr):
sol = xr.DataArray(sol, coords=coords, dims=dims)
assert img.equals(sol)
img = tf.interpolate(x, cmap=cmap, how='eq_hist')
sol = np.array([[0, 4291543295, 4288846335],
[4286609919, 0, 4283518207],
[4281281791, 4278190335, 0]], dtype='u4')
sol = xr.DataArray(sol, coords=coords, dims=dims)
sol = xr.DataArray(eq_hist_sol[attr], coords=coords, dims=dims)
assert img.equals(sol)
img = tf.interpolate(x, cmap=cmap, how=lambda x: x ** 2)
img = tf.interpolate(x, cmap=cmap,
how=lambda x, mask: np.where(mask, np.nan, x ** 2))
sol = np.array([[0, 4291543295, 4291148543],
[4290030335, 0, 4285557503],
[4282268415, 4278190335, 0]], dtype='u4')
Expand Down Expand Up @@ -97,24 +104,25 @@ def test_colorize():
colors = [(255, 0, 0), '#0000FF', 'orange']

img = tf.colorize(cat_agg, colors, how='log')
sol = np.array([[3137273856, 2449494783],
[4266997674, 3841982719]])
sol = np.array([[2583625728, 335565567],
[4283774890, 3707764991]], dtype='u4')
sol = tf.Image(sol, coords=coords, dims=dims)
assert img.equals(sol)
colors = dict(zip('abc', colors))
img = tf.colorize(cat_agg, colors, how='cbrt')
sol = np.array([[3070164992, 2499826431],
[4283774890, 3774873855]])
sol = np.array([[2650734592, 335565567],
[4283774890, 3657433343]], dtype='u4')
sol = tf.Image(sol, coords=coords, dims=dims)
assert img.equals(sol)
img = tf.colorize(cat_agg, colors, how='linear')
sol = np.array([[1660878848, 989876991],
[4283774890, 2952790271]])
sol = np.array([[1140785152, 335565567],
[4283774890, 2701132031]], dtype='u4')
sol = tf.Image(sol, coords=coords, dims=dims)
assert img.equals(sol)
img = tf.colorize(cat_agg, colors, how=lambda x: x ** 2)
sol = np.array([[788463616, 436228863],
[4283774890, 2080375039]])
img = tf.colorize(cat_agg, colors,
how=lambda x, m: np.where(m, np.nan, x) ** 2)
sol = np.array([[503250944, 335565567],
[4283774890, 1744830719]], dtype='u4')
sol = tf.Image(sol, coords=coords, dims=dims)
assert img.equals(sol)

Expand Down Expand Up @@ -262,9 +270,10 @@ def test_eq_hist():
data = np.random.normal(size=300**2)
data[np.random.randint(300**2, size=100)] = np.nan
data = (data - np.nanmin(data)).reshape((300, 300))
eq = tf.eq_hist(data)
mask = np.isnan(data)
eq = tf.eq_hist(data, mask)
check_eq_hist_cdf_slope(eq)
assert (np.isnan(eq) == np.isnan(data)).all()
assert (np.isnan(eq) == mask).all()
# Integer
data = np.random.normal(scale=100, size=(300, 300)).astype('i8')
data = data - data.min()
Expand Down
59 changes: 35 additions & 24 deletions datashader/transfer_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,16 @@ def stack(*imgs, **kwargs):
return Image(out, coords=imgs[0].coords, dims=imgs[0].dims)


def eq_hist(data, nbins=256):
def eq_hist(data, mask=None, nbins=256):
"""Return a numpy array after histogram equalization.
For use in `interpolate`.
Parameters
----------
data : ndarray
mask : ndarray, optional
Boolean array of missing points. Where True, the output will be `NaN`.
nbins : int, optional
Number of bins to use. Note that this argument is ignored for integer
arrays, which bin by the integer values directly.
Expand All @@ -82,22 +84,24 @@ def eq_hist(data, nbins=256):
"""
if not isinstance(data, np.ndarray):
raise TypeError("data must be np.ndarray")
if np.issubdtype(data.dtype, np.integer):
hist = np.bincount(data.ravel())
data2 = data if mask is None else data[~mask]
if np.issubdtype(data2.dtype, np.integer):
hist = np.bincount(data2.ravel())
bin_centers = np.arange(len(hist))
idx = np.nonzero(hist)[0][0]
hist, bin_centers = hist[idx:], bin_centers[idx:]
else:
hist, bin_edges = np.histogram(data[~np.isnan(data)], bins=nbins)
hist, bin_edges = np.histogram(data2, bins=nbins)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
cdf = hist.cumsum()
cdf = cdf / float(cdf[-1])
return np.interp(data.flat, bin_centers, cdf).reshape(data.shape)
out = np.interp(data.flat, bin_centers, cdf).reshape(data.shape)
return out if mask is None else np.where(mask, np.nan, out)


_interpolate_lookup = {'log': np.log1p,
'cbrt': lambda x: x ** (1/3.),
'linear': lambda x: x,
_interpolate_lookup = {'log': lambda d, m: np.log1p(np.where(m, np.nan, d)),
'cbrt': lambda d, m: np.where(m, np.nan, d)**(1/3.),
'linear': lambda d, m: np.where(m, np.nan, d),
'eq_hist': eq_hist}


Expand Down Expand Up @@ -128,9 +132,10 @@ def interpolate(agg, low=None, high=None, cmap=None, how='cbrt'):
Default is `["lightblue", "darkblue"]`
how : str or callable, optional
The interpolation method to use. Valid strings are 'cbrt' [default],
'log', 'linear', and 'eq_hist'. Callables take a 2-dimensional array of
magnitudes at each pixel, and should return a numeric array of the same
shape.
'log', 'linear', and 'eq_hist'. Callables take 2 arguments - a
2-dimensional array of magnitudes at each pixel, and a boolean mask
array indicating missingness. They should return a numeric array of the
same shape, with `NaN`s where the mask was True.
"""
if not isinstance(agg, xr.DataArray):
raise TypeError("agg must be instance of DataArray")
Expand All @@ -147,12 +152,13 @@ def interpolate(agg, low=None, high=None, cmap=None, how='cbrt'):
"Instead use `cmap=[low, high]`")
warnings.warn(w)
cmap = [low or cmap[0], high or cmap[1]]
offset = agg.min()
if offset == 0:
agg = agg.where(agg > 0)
offset = agg.min()
how = _normalize_interpolate_how(how)
data = how(agg.data - offset.data)
offset = agg.min().data
mask = agg.isnull()
if offset == 0:
mask = mask | (agg <= 0)
offset = agg.data[agg.data > 0].min()
data = how(agg.data - offset, mask.data)
span = [np.nanmin(data), np.nanmax(data)]
if isinstance(cmap, list):
rspan, gspan, bspan = np.array(list(zip(*map(rgb, cmap))))
Expand Down Expand Up @@ -185,9 +191,10 @@ def colorize(agg, color_key, how='cbrt', min_alpha=20):
record fields.
how : str or callable, optional
The interpolation method to use. Valid strings are 'cbrt' [default],
'log', and 'linear'. Callables take a 2-dimensional array of
magnitudes at each pixel, and should return a numeric array of the same
shape.
'log', and 'linear'. Callables take 2 arguments - a 2-dimensional
array of magnitudes at each pixel, and a boolean mask array indicating
missingness. They should return a numeric array of the same shape, with
`NaN`s where the mask was True.
min_alpha : float, optional
The minimum alpha value to use for non-empty pixels, in [0, 255].
"""
Expand All @@ -209,11 +216,15 @@ def colorize(agg, color_key, how='cbrt', min_alpha=20):
r = (data.dot(rs)/total).astype(np.uint8)
g = (data.dot(gs)/total).astype(np.uint8)
b = (data.dot(bs)/total).astype(np.uint8)
a = _normalize_interpolate_how(how)(total)
a = ((255 - min_alpha) * a/a.max() + min_alpha).astype(np.uint8)
white = (total == 0)
r[white] = g[white] = b[white] = 255
a[white] = 0
offset = total.min()
mask = np.isnan(total)
if offset == 0:
mask = mask | (total <= 0)
offset = total[total > 0].min()
a = _normalize_interpolate_how(how)(total - offset, mask)
a = np.interp(a, [np.nanmin(a), np.nanmax(a)],
[min_alpha, 255], left=0, right=255).astype(np.uint8)
r[mask] = g[mask] = b[mask] = 255
return Image(np.dstack([r, g, b, a]).view(np.uint32).reshape(a.shape),
dims=agg.dims[:-1], coords=list(agg.coords.values())[:-1])

Expand Down

0 comments on commit b26a5d4

Please sign in to comment.