
Refactor histogram to use blockwise #49

Merged
merged 14 commits on Jun 8, 2021

Conversation

rabernat
Contributor

@rabernat rabernat commented Apr 30, 2021

As discussed in ocean-transport/coiled_collaboration#8, I feel like there must be a more efficient implementation of xhistogram's core logic in terms of handling dask arrays. A central problem with our implementation is the use of reshape on dask arrays. This causes inefficient chunk manipulations (see dask/dask#5544).

We should be able to just apply the bin_count function to every block and then sum over all blocks to get the total bin count. If we go this route, we will no longer use dask.array.histogram at all. This should result in much simpler dask graphs.

I have not actually figured out how to solve the problem, but this PR gets things started in that direction. The challenge is in determining the arguments to pass to map_blocks.

@dcherian

dcherian commented Apr 30, 2021

👍 You know the number of bins or groups so you can "tree reduce" it without any reshaping as you say. The current version of dask's bincount does exactly that for the 1d case:
https://github.com/dask/dask/blob/98a65849a6ab2197ac9e24838d1effe701a441dc/dask/array/routines.py#L677-L698

It's what inspired the core of my groupby experiment: https://github.com/dcherian/dask_groupby/blob/373a4dda4d618848b5d02cca551893fee36df092/dask_groupby/core.py#L361-L368 which only reshapes if you know absolutely nothing about the "expected" groups here: https://github.com/dcherian/dask_groupby/blob/373a4dda4d618848b5d02cca551893fee36df092/dask_groupby/core.py#L604-L608

IIUC your proposal is implemented in tensordot as a blockwise-which-adds-dummy-axis followed by sum.
https://github.com/dask/dask/blob/98a65849a6ab2197ac9e24838d1effe701a441dc/dask/array/routines.py#L295-L310

I found blockwise a lot easier to handle than map_blocks for this stuff
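
For concreteness, a minimal sketch of that blockwise-plus-dummy-axis-plus-sum pattern for a single 2-D dask array with fixed NumPy bins might look like the following (the helper name _block_hist and the shapes are purely illustrative, not code from this PR):

import numpy as np
import dask.array as da

bins = np.linspace(0, 1, 11)
nbins = len(bins) - 1

def _block_hist(a):
    # histogram of a single chunk, with a length-1 axis per input dimension
    # so that every chunk contributes its own partial result
    h, _ = np.histogram(a, bins=bins)
    return h.reshape((1,) * a.ndim + (nbins,))

x = da.random.random((1000, 1000), chunks=(250, 500))

partial = da.blockwise(
    _block_hist, "ijb",              # keep the block structure along i, j; add a bin axis b
    x, "ij",
    new_axes={"b": nbins},
    adjust_chunks={"i": 1, "j": 1},  # each block yields a single (1, 1, nbins) slab
    dtype=np.intp,
)
hist = partial.sum(axis=(0, 1))      # "tree reduce" the per-chunk histograms
assert (hist.compute() == np.histogram(x.compute(), bins=bins)[0]).all()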

@jrbourbeau
Contributor

cc @gjoseph92

@rabernat
Contributor Author

@dcherian - at this point it is clear that you know a lot more about this than me. Would you have any interest in taking a crack at implementing the approach you just described? I think this PR sets it up nicely so all you have to do is plug in the tree reduce code.

Comment on lines 300 to 314
bin_counts = dsa.map_blocks(
    _bincount,
    all_arrays,
    weights,
    axis,
    bins,
    density,
    drop_axis=axis,
    new_axis="what?",
    chunks="what?",
    meta=np.array((), dtype=np.int64),
)
# sum over the block dims
block_dims = "?"
bin_counts = bin_counts.sum(block_dims)
Contributor Author

@rabernat rabernat Apr 30, 2021


This is really where the new implementation needs to go. If not using map_blocks then something else that does the equivalent thing.


Well I tried to start looking at it but nothing works without #48 being fixed

Contributor


There's an upstream fix in dask (xref dask/dask#7594) which should be merged soon

@gjoseph92
Contributor

Okay, I think this is the skeleton for how we could implement it: https://gist.github.com/gjoseph92/1ceaf28937f485012bc1540eb2e8da7d

We just need the actual bincount-ish function to be applied over the chunk pairs. And to handle weights, and the bins being dask arrays. But I think those are "just" a matter of broadcasting and lining up the dimension names to what I have here.

Now that I'm seeing how much cleaner the graphs can be in the n-dimensional chunking case (I think the examples I'd looked at earlier all had one chunk along the aggregated dimensions), I fully agree with @rabernat that blockwise could improve the graph structure a lot.
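
As a rough stand-in for the gist (not reproduced here), the same pattern extended to two input variables plus weights might look something like this; the names and the fixed bins are illustrative assumptions, and the real skeleton would handle bins and weights more generally:

import numpy as np
import dask.array as da

bins1 = np.linspace(0, 1, 11)
bins2 = np.linspace(0, 1, 8)

def _block_hist2d(a, b, w):
    # weighted joint histogram of one chunk triple, with a length-1 axis per
    # input dimension so every chunk contributes its own partial result
    h, _, _ = np.histogram2d(a.ravel(), b.ravel(), bins=[bins1, bins2], weights=w.ravel())
    return h.reshape((1,) * a.ndim + h.shape)

shape, chunks = (20, 320, 320), (20, 80, 80)
data1 = da.random.random(shape, chunks=chunks)
data2 = da.random.random(shape, chunks=chunks)
weights = da.random.random(shape, chunks=chunks)

partial = da.blockwise(
    _block_hist2d, "zyxab",
    data1, "zyx", data2, "zyx", weights, "zyx",
    new_axes={"a": len(bins1) - 1, "b": len(bins2) - 1},
    adjust_chunks={"z": 1, "y": 1, "x": 1},
    dtype=float,
)
hist = partial.sum(axis=(0, 1, 2))   # joint (10, 7) weighted histogram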

@dcherian

dcherian commented May 4, 2021

👍 That's exactly it AFAICT. It's the same trick as tensordot and isin.

@rabernat
Contributor Author

rabernat commented May 6, 2021

@gjoseph92 that looks amazing. Exactly what I had in mind originally with this PR but didn't have the dask chops to actually implement. Let me know how I can help move this forward.

@TomNicholas
Member

I'm happy to have a go at implementing this too. I would eventually like to integrate it upstream in xarray, but only once we've got the best dask function for the job working.

One question - what would the numpy equivalent of @gjoseph92 's solution be? Is it just to apply bincount once over the whole array without calling blockwise?

@dougiesquire
Contributor

I'm happy to have a go at implementing this too. I would eventually like to integrate it upstream in xarray, but only once we've got the best dask function for the job working.

@TomNicholas, for someone that's not up with the lingo, do you mean move histogram functionality into xarray?

One question - what would the numpy equivalent of @gjoseph92 's solution be? Is it just to apply bincount once over the whole array without calling blockwise?

For 1D arrays at least, I think numpy histogram actually loops over blocks and cumulatively sums the bincount applied to each block, with a note that this is both faster and more memory efficient for large arrays.
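
For reference, that looping strategy looks roughly like this standalone sketch (chunked_histogram and the block size are made-up names for illustration, not NumPy's actual internals):

import numpy as np

def chunked_histogram(a, bins, block_size=65536):
    # accumulate per-block counts rather than binning the whole array at once,
    # which keeps the temporaries (e.g. per-element bin indices) small
    a = np.ravel(a)
    counts = np.zeros(len(bins) - 1, dtype=np.intp)
    for start in range(0, a.size, block_size):
        counts += np.histogram(a[start:start + block_size], bins=bins)[0]
    return counts

x = np.random.random(1_000_000)
bins = np.linspace(0, 1, 11)
assert (chunked_histogram(x, bins) == np.histogram(x, bins=bins)[0]).all()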

@TomNicholas
Member

do you mean move histogram functionality into xarray

Yes exactly. It's arguably within the scope of xarray (which already has a .hist method, it just currently reduces over all dimensions), and has functions for groupby operations and similar types of reshaping already.

In the long term the advantages would be:

  • Domain-agnostic functionality would be living in a domain-agnostic library
  • More users would use the code & benefit from it
  • More users means more chance of finding bugs, and more help with extending to add new features
  • Not having the maintenance burden of a separate package

I think numpy histogram actually loops over blocks

Oh that's interesting, thanks!

@dougiesquire
Contributor

Yes exactly. It's arguably within the scope of xarray

I think this is a terrific idea. Happy to help out however I can

@TomNicholas
Member

I think this is a terrific idea. Happy to help out however I can

Great! Well anyone can do the PR(s) to xarray - I would be happy to provide guidance & code reviews over there if you wanted to try it?

The steps as I see it are:

  • Get the dask.blockwise bare version working efficiently, probably here (could build off Ryan's PR in another PR if we want)
  • Discuss the external API we want in a "feature proposal" issue on xarray (I'll open one for that with some ideas)
  • Copy over whatever tests + docs are transferable to an xarray PR
  • Get that minimum version working in the xarray PR
  • Once the minimum version works in xarray I would probably only worry about adding weighted functionality and so on in a later PR
  • Eventually deprecate xhistogram
  • profit

@dougiesquire
Contributor

Yup I think it's a good idea to get the dask.blockwise implementation working well here as a first step 👍

@TomNicholas
Member

So looking at the code in xhistogram.core, I have some questions:

To implement @gjoseph92's solution (which looks great), we would need to write a new bincount function that matches the signature described in his notebook. The _bincount function we currently have has a different signature (though I'm not sure exactly what signature, because it's an internal function and has no docstring), so it would need to change significantly. _bincount then calls other functions like _bincount_2d_vectorized, which I think are specific to the reshape-based dask algorithm we currently use. If we are changing all of these functions, are we essentially just rewriting the whole of xhistogram.core? I'm not saying we shouldn't do that, I'm just trying to better understand what's required. Can any of this be copied over from other implementations (e.g. dask.array.histogramdd)?

@gjoseph92
Contributor

If we are changing all these functions are we essentially just rewriting the whole of xhistogram.core

Yes, I think we sort of are. xhistogram's current approach is dask-agnostic (ish, besides _dispatch_bincount)—it uses a bunch of NumPy functions and lets dask parallelize them implicitly. With the blockwise approach, we'd sort of have a shared "inner kernel" (the core bincount operation), then slightly different code paths for mapping that over a NumPy array versus explicitly parallelizing it over a dask array.

The difference is very small: for a NumPy array, apply the core bincount function directly to the array. For a dask array, use blockwise to apply it. (You can think of the NumPy case as just a 1-chunk dask array.) Then in both cases, finish with bincounts.sum(axis=axis).

So in a way, the bincount "inner kernel" is all of the current xhistogram core, just simplified to not have to worry about dask arrays, and to return a slightly different shape. I think this change might actually make xhistogram a bit easier to read/reason about. Additionally, I think we should consider writing this bincount part in numba or Cython eventually—not just because it's faster, but because being able to just write a for loop would be way easier to read than the current reshape and ravel_multi_index logic, which is hard to wrap your head around.
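
To make the "inner kernel" idea concrete, a bare-bones version of the digitize/ravel_multi_index/bincount logic might look like the sketch below (simplified: it clips out-of-range values into the edge bins and ignores kept axes and density, unlike the real _bincount):

import numpy as np

def _ndhist_kernel(*arrays, bins, weights=None):
    # digitize each variable, combine the per-variable bin indices into one
    # flat index with ravel_multi_index, then do a single bincount
    nbins = [len(b) - 1 for b in bins]
    indices = [np.clip(np.digitize(a.ravel(), b) - 1, 0, n - 1)
               for a, b, n in zip(arrays, bins, nbins)]
    flat = np.ravel_multi_index(indices, nbins)
    counts = np.bincount(flat,
                         weights=None if weights is None else weights.ravel(),
                         minlength=np.prod(nbins))
    return counts.reshape(nbins)

a = np.random.random(10_000)
b = np.random.random(10_000)
bins = [np.linspace(0, 1, 11), np.linspace(0, 1, 8)]
np.testing.assert_allclose(_ndhist_kernel(a, b, bins=bins),
                           np.histogram2d(a, b, bins=bins)[0])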

Also, it might be worth passing NumPy arrays of sufficient size (or any size) into dask.array.from_array, putting them through the dask logic, and computing the result with scheduler="threads". This would be a quick and easy way to get the loop-over-blocks behavior @dougiesquire was describing, (maybe) a lower peak memory footprint, plus some parallelism.

@TomNicholas
Member

That's very helpful, thanks @gjoseph92 .

I think this change might actually make xhistogram a bit easier to read/reason about.

I agree.

Additionally, I think we should consider writing this bincount part in numba or Cython eventually

That would be cool, but it would also affect whether or not we could integrate the functionality upstream in xarray, so I suggest that we get a version working with just dask initially so we can see if it's fast "enough"?

computing the result with scheduler="threads"

That might even be like a "best of both" solution, ish?

@dougiesquire
Contributor

Thanks @gjoseph92 for that very helpful description!

@TomNicholas I'm going on leave next week and I'm hoping to work on this a little while I'm away. My plan is to spend some time understanding/implementing the blockwise approach and then to open another PR once things are somewhat functional. But I want to make sure I don't tread on your toes or duplicate your efforts. What's the best way to align our efforts?

@TomNicholas
Member

TomNicholas commented May 12, 2021

I'm hoping to work on this a little during that

Great!

But I want to make sure I don't tread on your toes or duplicate your efforts.

@dougiesquire I wouldn't worry about that too much. I think we will both want to understand the resultant code, so time spent understanding it is always well-spent. For development we should just both post whatever we try out - as a gist, a PR, or we could even start a new dev branch on this repo.


bin_counts = _bincount_2d_vectorized(
    *all_arrays_reshaped, bins=bins, weights=weights_reshaped, density=density
)
Contributor Author

@rabernat rabernat May 14, 2021


This is not obvious, but bin_counts is not necessarily 2D. Its ndim is 1 + len(all_arrays), so it represents an ND histogram.

@rabernat rabernat marked this pull request as ready for review May 15, 2021 02:41
@TomNicholas
Member

With weights you get this:

import numpy as np
import dask.array as dsa
import xhistogram.core as xhc  # run in a notebook, where display() is available

csize = 80
nz = 20
ntime = 2
fac = 4

bins1 = np.arange(10)
bins2 = np.arange(8)

data1 = dsa.random.random((nz, fac * csize, fac * csize),
                          chunks=(nz, csize, csize))
data2 = dsa.random.random((nz, fac * csize, fac * csize),
                          chunks=(nz, csize, csize))
weights = dsa.random.random((nz, fac * csize, fac * csize),
                            chunks=(nz, csize, csize))

display(data1)
hist = xhc.histogram(data1, data2, bins=[bins1, bins2], weights=weights, axis=[0, 1])
hist_eager = xhc.histogram(data1.compute(), data2.compute(), bins=[bins1, bins2],
                           weights=weights.compute(), axis=[0, 1])
display(hist)
display(hist.visualize())
np.testing.assert_allclose(hist.compute(), hist_eager)

[image: output of hist.visualize(), the dask task graph]

@jbusecke
Contributor

Just beautiful! I suspect this will help tremendously over at ocean-transport/coiled_collaboration#9.

Do you think this is ready to test under real-world conditions?

@TomNicholas
Member

Do you think this is ready to test under real-world conditions?

You could try it, but probably not, because chunking has no explicit tests, apart from the ones I added in #57, which do not all yet pass 😕

@TomNicholas
Member

@jbusecke scratch that - it looks like it was my tests that were the problem after all - should be green to try this out now!

@jbusecke
Contributor

jbusecke commented May 27, 2021

Whoa! This works really well!

Here's the dask task stream from a synthetic example that would previously fail often (ocean-transport/coiled_collaboration#9). This example uses two variables and weights.
[image: dask dashboard task stream screenshot]

I ran this on the google deployment with the 'hacky pip install' described here. Will try it out for realistic CMIP6 data next.

I can't thank you enough for implementing this! I know who I will thank at the next CDSLab meeting!

@jbusecke
Contributor

Ok so after some tuning of the chunksize I got some pretty darn satisfying results for the real world example:
[image: dask dashboard screenshot for the real-world example]

This crunched through a single member in under 2 minutes with 20 workers. Absolutely no spilling!

else:
    import dask.array as dsa

    return any([isinstance(a, dsa.core.Array) for a in args])
Contributor


Suggested change
return any([isinstance(a, dsa.core.Array) for a in args])
return any(isinstance(a, dsa.core.Array) for a in args)

nit: don't need to construct an intermediate list

Comment on lines 294 to 298
# here I am assuming all the arrays have the same shape
# probably needs to be generalized
input_indexes = [tuple(range(a.ndim)) for a in all_arrays]
input_index = input_indexes[0]
assert all([ii == input_index for ii in input_indexes])
Contributor


nit: it seems this is only needed in the if _any_dask_array case; maybe move it into that conditional?

here I am assuming all the arrays have the same shape

Maybe you'd want to np.broadcast_arrays(*all_arrays) here (note that this will dispatch to da.broadcast_arrays automatically thanks to NEP-18)?

Alternatively, I think you could get away without broadcasting with some slight modifications to the blockwise indices. Specifically, rather than using input_index for all arrays, use each array's own indices. Then blockwise will basically express any broadcasting for you in the graph:

input_indexes = [tuple(range(a.ndim)) for a in all_arrays]
largest_index = max(input_indexes, key=len)
if axis is not None:
    drop_axes = tuple([ii for ii in largest_index if ii in axis])
else:
    drop_axes = largest_index

out_index = largest_index + tuple(new_axes)
...
blockwise_args = []
for arg, index in zip(all_arrays, input_indexes):
    blockwise_args.append(arg)
    blockwise_args.append(index)

Member


@gjoseph92 is there any performance reason to prefer making blockwise handle the broadcasting? It seems simplest to just broadcast everything at the start, but does that mean carrying around larger arrays than necessary?

Contributor


The arrays shouldn't actually increase in physical size; their shape will just be expanded. The only advantage of using blockwise alone is that the graph will be a little smaller:

  • No extra reshape layers added prior to the blockwise (which are also fully materialized layers and slightly more expensive to serialize to the scheduler)
  • The low-level graphs of the arrays input into blockwise won't contain superfluous keys for the extra broadcast dimensions

In practice, I very much doubt it would be an appreciable difference, but not having compared them I can't say for sure. There is also a slight optimization advantage with pure-blockwise when there is exactly 1 input array, and it's already the result of a blockwise operation.

Member


That's interesting, thanks.

However it's just occurred to me that surely relying on blockwise to do all the broadcasting is not an option, simply because the non-dask case doesn't use blockwise. Supporting a numpy-only code path (as we would likely want both for debugging and for fitting xarray's standard practice) means broadcasting all arrays before they are fed to _bincount.
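
For what it's worth, the NEP-18 dispatch mentioned above means a single np.broadcast_arrays call can serve both code paths; a tiny illustration with arbitrary shapes (not code from this PR):

import numpy as np
import dask.array as da

data = da.random.random((20, 320, 320), chunks=(20, 80, 80))
weights = np.random.random((320,))        # lower-dimensional weights

# np.broadcast_arrays sees a dask array among its inputs and dispatches to
# da.broadcast_arrays via __array_function__, so both outputs stay lazy
data_b, weights_b = np.broadcast_arrays(data, weights)
print(type(weights_b), weights_b.shape)   # dask Array, shape (20, 320, 320)

# with plain NumPy inputs the same call returns broadcast NumPy views
data_np, weights_np = np.broadcast_arrays(np.asarray(data), weights)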

Comment on lines +152 to +153
# is this necessary?
all_arrays_broadcast = broadcast_arrays(*all_arrays)
Contributor


If you go with the change suggested in #49 (comment) (no np.broadcast in top-level histogram), then I think this broadcast is necessary. Otherwise, if a0 happens to have dimensions that would be added or expanded by broadcasting, then your kept_axes_shape calculation could be wrong, and reshape_input might move the wrong axes.

However, if you added an np.broadcast in top-level histogram, then I think you could remove this one here.

Contributor Author


So with this PR, none of this code path will ever see dask arrays. So we can probably remove the duck_array_ops module.

@TomNicholas
Member

TomNicholas commented May 28, 2021

So now that we've successfully chunked over all dims, do we still need any of the "block size" logic? The main benefit of being able to specify block_size is a lowered memory footprint, right? But if we can now chunk over any dim, then the dask chunks can always be made smaller for a given problem, and the user can always lower their memory footprint by using smaller chunks instead?

Or do we still want it for the same reasons that numpy.histogram has it?

(There is also no block used anywhere in np.histogramdd)

@rabernat
Contributor Author

rabernat commented Jun 1, 2021

Or do we still want it for the same reasons that numpy.histogram has it?

My understanding was that this is a low-level optimization, aimed at making the core bincount routine faster. Dask is responsible for scaling out, but having a faster low-level routine is always desirable as well. That said, I have not benchmarked anything to verify it actually makes a difference.

Comment on lines 381 to 382
np.histogram_bin_edges(a, b, r, weights)
for a, b, r in zip(all_arrays, bins, range)
Contributor Author


I brought this code in when I merged with upstream/master and picked up the new bin edges logic introduced by @dougiesquire in #45. Unfortunately the tests are going to fail because of shape mismatches between arrays and weights. Due to the refactor in this PR, the high-level histogram function no longer has access to the reshaped / broadcasted arrays. So we will need some new solution here.

Contributor


Maybe it's cleanest/easiest then to np.broadcast_arrays(*all_arrays) early on as @gjoseph92 suggests in #49 (comment)? Then we also don't need to worry about each array potentially having a different input_index.

Another option I guess would be to do the broadcasting in xarray.histogram and have core.histogram expect that all args and weights are the same shape (somewhat replicating what numpy.histogram expects).

Member

@TomNicholas TomNicholas Jun 4, 2021


@rabernat - the docstring of np.histogram_bin_edges says that the weights argument "is currently not used by any of the bin estimators, but may be in the future." So it currently doesn't do anything: you can make all the tests pass again simply by deleting the weights argument from the call to np.histogram_bin_edges.

Separately I do think that we should consider broadcasting early on to make the indexing easier to follow, but it would be nice to get this PR merged and discuss that refactoring elsewhere.

@rabernat
Contributor Author

rabernat commented Jun 8, 2021

Tom, the only thing that failed after I accepted your suggestion was the linting. So please go ahead and merge this if you think it's ready! Lots of stuff is blocking on this PR.

@jbusecke
Contributor

🚀
