Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regression: input arrays are not broadcast #61

Merged
merged 4 commits into from
Jun 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/environment-3.7.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dependencies:
- python=3.7
- xarray
- dask
- numpy=1.16
- numpy=1.17
- pytest
- pip
- pip:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"Topic :: Scientific/Engineering",
]

INSTALL_REQUIRES = ["xarray>=0.12.0", "dask", "numpy>=1.16"]
INSTALL_REQUIRES = ["xarray>=0.12.0", "dask", "numpy>=1.17"]
PYTHON_REQUIRES = ">=3.7"

DESCRIPTION = "Fast, flexible, label-aware histograms for numpy and xarray"
Expand Down
43 changes: 18 additions & 25 deletions xhistogram/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,7 @@ def _bincount_2d_vectorized(


def _bincount(*all_arrays, weights=False, axis=None, bins=None, density=None):

# is this necessary?
all_arrays_broadcast = broadcast_arrays(*all_arrays)

a0 = all_arrays_broadcast[0]
a0 = all_arrays[0]

do_full_array = (axis is None) or (set(axis) == set(_range(a0.ndim)))

Expand Down Expand Up @@ -226,7 +222,7 @@ def reshape_input(a):
d = reshape(c, (new_dim_0, new_dim_1))
return d

all_arrays_reshaped = [reshape_input(a) for a in all_arrays_broadcast]
all_arrays_reshaped = [reshape_input(a) for a in all_arrays]

if weights:
weights_array = all_arrays_reshaped.pop()
Expand Down Expand Up @@ -257,8 +253,8 @@ def histogram(
Parameters
----------
args : array_like
Input data. The number of input arguments determines the dimensonality
of the histogram. For example, two arguments prodocue a 2D histogram.
Input data. The number of input arguments determines the dimensionality
of the histogram. For example, two arguments produce a 2D histogram.
All args must have the same size.
bins : int, str or numpy array or a list of ints, strs and/or arrays, optional
If a list, there should be one entry for each item in ``args``.
Expand Down Expand Up @@ -358,21 +354,20 @@ def histogram(

dtype = "i8" if not has_weights else weights.dtype

# here I am assuming all the arrays have the same shape
# probably needs to be generalized
input_indexes = [tuple(_range(a.ndim)) for a in all_arrays]
input_index = input_indexes[0]
assert all([ii == input_index for ii in input_indexes])
# Broadcast input arrays. Note that this dispatches to `dsa.broadcast_arrays` as necessary.
all_arrays = broadcast_arrays(*all_arrays)
# Since all arrays now have the same shape, just get the axes of the first.
input_axes = tuple(_range(all_arrays[0].ndim))

# Some sanity checks and format bins and range correctly
bins = _ensure_correctly_formatted_bins(bins, n_inputs)
range = _ensure_correctly_formatted_range(range, n_inputs)

# histogram_bin_edges trigges computation on dask arrays. It would be possible
# histogram_bin_edges triggers computation on dask arrays. It would be possible
# to write a version of this that doesn't trigger when `range` is provided, but
# for now let's just use np.histogram_bin_edges
if is_dask_array:
if not all([isinstance(b, np.ndarray) for b in bins]):
if not all(isinstance(b, np.ndarray) for b in bins):
raise TypeError(
"When using dask arrays, bins must be provided as numpy array(s) of edges"
)
Expand All @@ -382,11 +377,11 @@ def histogram(
]
bincount_kwargs = dict(weights=has_weights, axis=axis, bins=bins, density=density)

# keep these axes in the inputs
# remove these axes from the inputs
if axis is not None:
drop_axes = tuple([ii for ii in input_index if ii in axis])
drop_axes = tuple(axis)
else:
drop_axes = input_index
drop_axes = input_axes

if _any_dask_array(weights, *all_arrays):
# We should be able to just apply the bin_count function to every
Expand All @@ -405,16 +400,14 @@ def histogram(

adjust_chunks = {i: (lambda x: 1) for i in drop_axes}

new_axes = {
max(input_index) + 1 + i: axis_len
for i, axis_len in enumerate([len(bin) - 1 for bin in bins])
}
out_index = input_index + tuple(new_axes)
new_axes_start = max(input_axes) + 1
new_axes = {new_axes_start + i: len(bin) - 1 for i, bin in enumerate(bins)}
out_index = input_axes + tuple(new_axes)

blockwise_args = []
for arg in all_arrays:
blockwise_args.append(arg)
blockwise_args.append(input_index)
blockwise_args.append(input_axes)

bin_counts = dsa.blockwise(
_bincount,
Expand All @@ -432,7 +425,7 @@ def histogram(
bin_counts = _bincount(*all_arrays, **bincount_kwargs).squeeze(drop_axes)

if density:
# Normalise by dividing by bin counts and areas such that all the
# Normalize by dividing by bin counts and areas such that all the
# histogram data integrated over all dimensions = 1
bin_widths = [np.diff(b) for b in bins]
if n_inputs == 1:
Expand Down
34 changes: 31 additions & 3 deletions xhistogram/test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,34 @@ def test_histogram_results_2d():
np.testing.assert_array_equal(hist, h)


@pytest.mark.parametrize("dask", [False, True])
def test_histogram_results_2d_broadcasting(dask):
nrows, ncols = 5, 20
data_a = np.random.randn(ncols)
data_b = np.random.randn(nrows, ncols)
nbins_a = 9
bins_a = np.linspace(-4, 4, nbins_a + 1)
nbins_b = 10
bins_b = np.linspace(-4, 4, nbins_b + 1)

if dask:
test_data_a = dsa.from_array(data_a, chunks=3)
test_data_b = dsa.from_array(data_b, chunks=(2, 7))
else:
test_data_a = data_a
test_data_b = data_b

h, _ = histogram(test_data_a, test_data_b, bins=[bins_a, bins_b])
assert h.shape == (nbins_a, nbins_b)

hist, _, _ = np.histogram2d(
np.broadcast_to(data_a, data_b.shape).ravel(),
data_b.ravel(),
bins=[bins_a, bins_b],
)
np.testing.assert_array_equal(hist, h)


def test_histogram_results_2d_density():
nrows, ncols = 5, 20
data_a = np.random.randn(nrows, ncols)
Expand Down Expand Up @@ -228,7 +256,7 @@ def test_histogram_shape(use_dask, block_size):


def test_histogram_dask():
""" Test that fails with dask arrays and inappropriate bins"""
"""Test that fails with dask arrays and inappropriate bins"""
shape = 10, 15, 12, 20
b = empty_dask_array(shape, chunks=(1,) + shape[1:])
histogram(b, bins=bins_arr) # Should work when bins is all numpy arrays
Expand All @@ -255,7 +283,7 @@ def test_histogram_dask():
],
)
def test_ensure_correctly_formatted_bins(in_out):
""" Test the helper function _ensure_correctly_formatted_bins"""
"""Test the helper function _ensure_correctly_formatted_bins"""
bins_in, n, bins_expected = in_out
if bins_expected is not None:
bins = _ensure_correctly_formatted_bins(bins_in, n)
Expand All @@ -277,7 +305,7 @@ def test_ensure_correctly_formatted_bins(in_out):
],
)
def test_ensure_correctly_formatted_range(in_out):
""" Test the helper function _ensure_correctly_formatted_range"""
"""Test the helper function _ensure_correctly_formatted_range"""
range_in, n, range_expected = in_out
if range_expected is not None:
range_ = _ensure_correctly_formatted_range(range_in, n)
Expand Down