diff --git a/ci/environment-3.7.yml b/ci/environment-3.7.yml index 355cfc8..64dca66 100644 --- a/ci/environment-3.7.yml +++ b/ci/environment-3.7.yml @@ -5,7 +5,7 @@ dependencies: - python=3.7 - xarray - dask - - numpy=1.16 + - numpy=1.17 - pytest - pip - pip: diff --git a/setup.py b/setup.py index 4e6f771..d119cce 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ "Topic :: Scientific/Engineering", ] -INSTALL_REQUIRES = ["xarray>=0.12.0", "dask", "numpy>=1.16"] +INSTALL_REQUIRES = ["xarray>=0.12.0", "dask", "numpy>=1.17"] PYTHON_REQUIRES = ">=3.7" DESCRIPTION = "Fast, flexible, label-aware histograms for numpy and xarray" diff --git a/xhistogram/core.py b/xhistogram/core.py index 97a1c7e..c13dda9 100644 --- a/xhistogram/core.py +++ b/xhistogram/core.py @@ -193,11 +193,7 @@ def _bincount_2d_vectorized( def _bincount(*all_arrays, weights=False, axis=None, bins=None, density=None): - - # is this necessary? - all_arrays_broadcast = broadcast_arrays(*all_arrays) - - a0 = all_arrays_broadcast[0] + a0 = all_arrays[0] do_full_array = (axis is None) or (set(axis) == set(_range(a0.ndim))) @@ -226,7 +222,7 @@ def reshape_input(a): d = reshape(c, (new_dim_0, new_dim_1)) return d - all_arrays_reshaped = [reshape_input(a) for a in all_arrays_broadcast] + all_arrays_reshaped = [reshape_input(a) for a in all_arrays] if weights: weights_array = all_arrays_reshaped.pop() @@ -257,8 +253,8 @@ def histogram( Parameters ---------- args : array_like - Input data. The number of input arguments determines the dimensonality - of the histogram. For example, two arguments prodocue a 2D histogram. + Input data. The number of input arguments determines the dimensionality + of the histogram. For example, two arguments produce a 2D histogram. All args must have the same size. bins : int, str or numpy array or a list of ints, strs and/or arrays, optional If a list, there should be one entry for each item in ``args``. @@ -358,21 +354,20 @@ def histogram( dtype = "i8" if not has_weights else weights.dtype - # here I am assuming all the arrays have the same shape - # probably needs to be generalized - input_indexes = [tuple(_range(a.ndim)) for a in all_arrays] - input_index = input_indexes[0] - assert all([ii == input_index for ii in input_indexes]) + # Broadcast input arrays. Note that this dispatches to `dsa.broadcast_arrays` as necessary. + all_arrays = broadcast_arrays(*all_arrays) + # Since all arrays now have the same shape, just get the axes of the first. + input_axes = tuple(_range(all_arrays[0].ndim)) # Some sanity checks and format bins and range correctly bins = _ensure_correctly_formatted_bins(bins, n_inputs) range = _ensure_correctly_formatted_range(range, n_inputs) - # histogram_bin_edges trigges computation on dask arrays. It would be possible + # histogram_bin_edges triggers computation on dask arrays. It would be possible # to write a version of this that doesn't trigger when `range` is provided, but # for now let's just use np.histogram_bin_edges if is_dask_array: - if not all([isinstance(b, np.ndarray) for b in bins]): + if not all(isinstance(b, np.ndarray) for b in bins): raise TypeError( "When using dask arrays, bins must be provided as numpy array(s) of edges" ) @@ -382,11 +377,11 @@ def histogram( ] bincount_kwargs = dict(weights=has_weights, axis=axis, bins=bins, density=density) - # keep these axes in the inputs + # remove these axes from the inputs if axis is not None: - drop_axes = tuple([ii for ii in input_index if ii in axis]) + drop_axes = tuple(axis) else: - drop_axes = input_index + drop_axes = input_axes if _any_dask_array(weights, *all_arrays): # We should be able to just apply the bin_count function to every @@ -405,16 +400,14 @@ def histogram( adjust_chunks = {i: (lambda x: 1) for i in drop_axes} - new_axes = { - max(input_index) + 1 + i: axis_len - for i, axis_len in enumerate([len(bin) - 1 for bin in bins]) - } - out_index = input_index + tuple(new_axes) + new_axes_start = max(input_axes) + 1 + new_axes = {new_axes_start + i: len(bin) - 1 for i, bin in enumerate(bins)} + out_index = input_axes + tuple(new_axes) blockwise_args = [] for arg in all_arrays: blockwise_args.append(arg) - blockwise_args.append(input_index) + blockwise_args.append(input_axes) bin_counts = dsa.blockwise( _bincount, @@ -432,7 +425,7 @@ def histogram( bin_counts = _bincount(*all_arrays, **bincount_kwargs).squeeze(drop_axes) if density: - # Normalise by dividing by bin counts and areas such that all the + # Normalize by dividing by bin counts and areas such that all the # histogram data integrated over all dimensions = 1 bin_widths = [np.diff(b) for b in bins] if n_inputs == 1: diff --git a/xhistogram/test/test_core.py b/xhistogram/test/test_core.py index e0df0ee..f6ebcc3 100644 --- a/xhistogram/test/test_core.py +++ b/xhistogram/test/test_core.py @@ -122,6 +122,34 @@ def test_histogram_results_2d(): np.testing.assert_array_equal(hist, h) +@pytest.mark.parametrize("dask", [False, True]) +def test_histogram_results_2d_broadcasting(dask): + nrows, ncols = 5, 20 + data_a = np.random.randn(ncols) + data_b = np.random.randn(nrows, ncols) + nbins_a = 9 + bins_a = np.linspace(-4, 4, nbins_a + 1) + nbins_b = 10 + bins_b = np.linspace(-4, 4, nbins_b + 1) + + if dask: + test_data_a = dsa.from_array(data_a, chunks=3) + test_data_b = dsa.from_array(data_b, chunks=(2, 7)) + else: + test_data_a = data_a + test_data_b = data_b + + h, _ = histogram(test_data_a, test_data_b, bins=[bins_a, bins_b]) + assert h.shape == (nbins_a, nbins_b) + + hist, _, _ = np.histogram2d( + np.broadcast_to(data_a, data_b.shape).ravel(), + data_b.ravel(), + bins=[bins_a, bins_b], + ) + np.testing.assert_array_equal(hist, h) + + def test_histogram_results_2d_density(): nrows, ncols = 5, 20 data_a = np.random.randn(nrows, ncols) @@ -228,7 +256,7 @@ def test_histogram_shape(use_dask, block_size): def test_histogram_dask(): - """ Test that fails with dask arrays and inappropriate bins""" + """Test that fails with dask arrays and inappropriate bins""" shape = 10, 15, 12, 20 b = empty_dask_array(shape, chunks=(1,) + shape[1:]) histogram(b, bins=bins_arr) # Should work when bins is all numpy arrays @@ -255,7 +283,7 @@ def test_histogram_dask(): ], ) def test_ensure_correctly_formatted_bins(in_out): - """ Test the helper function _ensure_correctly_formatted_bins""" + """Test the helper function _ensure_correctly_formatted_bins""" bins_in, n, bins_expected = in_out if bins_expected is not None: bins = _ensure_correctly_formatted_bins(bins_in, n) @@ -277,7 +305,7 @@ def test_ensure_correctly_formatted_bins(in_out): ], ) def test_ensure_correctly_formatted_range(in_out): - """ Test the helper function _ensure_correctly_formatted_range""" + """Test the helper function _ensure_correctly_formatted_range""" range_in, n, range_expected = in_out if range_expected is not None: range_ = _ensure_correctly_formatted_range(range_in, n)