Allow skipna in .dot() #4482

Open
heerad opened this issue Oct 2, 2020 · 13 comments

Comments

@heerad

heerad commented Oct 2, 2020

Is your feature request related to a problem? Please describe.
Right now there's no efficient way to do a dot product that skips over nan elements.

Describe the solution you'd like
I want to be able to treat the summation in dot as a nansum, controlled by a skipna option. Either this can be implemented directly, or an additional ufunc, xarray.ufuncs.nan_to_num, can be added and called on the inputs to dot. Unfortunately, numpy's nan_to_num triggers eager execution.

Describe alternatives you've considered
It's possible to implement this by hand, but it ends up being extremely inefficient in one of my use-cases:

  • (x*y).sum('dot_prod_dim', skipna=True) takes 30 seconds
  • x.dot(y) takes 1 second
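A rough NumPy sketch of one plausible reason for the gap, assuming `x` and `y` have dimensions that are not shared (the shapes below are made up, not heerad's data): the elementwise product broadcasts both inputs into a full intermediate before summing, while `np.einsum`, which `dot` builds on, contracts the shared dimension directly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical shapes: x has dims (a, k), y has dims (k, b); k plays "dot_prod_dim".
x = rng.standard_normal((200, 50))
y = rng.standard_normal((50, 300))

# Broadcast-then-sum, analogous to (x * y).sum("dot_prod_dim"):
# this materializes an (a, k, b) intermediate before reducing over k.
product = x[:, :, None] * y[None, :, :]
via_sum = product.sum(axis=1)

# Contraction, analogous to x.dot(y): einsum reduces over k directly,
# so no (a, k, b) array is ever allocated.
via_einsum = np.einsum("ak,kb->ab", x, y)

ratio = product.nbytes // via_einsum.nbytes
print(ratio)  # -> 50, i.e. the intermediate is k times larger than the result
```

Both paths give the same numbers; the difference is how much memory is touched along the way.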
@max-sixty
Collaborator

I agree this would be a welcome option.

As a workaround, you could fillna with zero?

@heerad
Author

heerad commented Oct 2, 2020

Great, looks like I missed that option. Thanks.

For reference, x.fillna(0).dot(y) takes 18 seconds in that same example, so a little better.

@mathause
Collaborator

mathause commented Oct 2, 2020

Yes, that would be very helpful. This is used for the weighted operations:

```python
if skipna or (skipna is None and da.dtype.kind in "cfO"):
    da = da.fillna(0.0)
# `dot` does not broadcast arrays, so this avoids creating a large
# DataArray (if `weights` has additional dimensions)
# maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`)
return dot(da, weights, dims=dim)
```

and it would be great if this could be done upstream. However, dot is implemented using np.einsum, which is quite a gnarly beast.
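A small NumPy illustration of why the masking step is needed: `np.einsum` has no skipna-style option, so a single NaN propagates into the contracted result. Zero-filling first, in the spirit of the `fillna(0.0)` line in the excerpt, restores a finite sum. The arrays are arbitrary examples.

```python
import numpy as np

x = np.array([1.0, np.nan, 3.0])
w = np.array([0.5, 0.25, 0.25])

# Plain contraction: the NaN in x poisons the whole sum.
plain = np.einsum("i,i->", x, w)

# Zero-fill x first (a NumPy stand-in for fillna(0.0)), then contract.
masked = np.einsum("i,i->", np.where(np.isnan(x), 0.0, x), w)

print(plain, masked)  # nan 1.25
```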

@shoyer
Member

shoyer commented Oct 6, 2020

I agree this would be welcome! Even if it isn't much faster than the options already shown here, at least we could point users to the best option we know of.

I suspect achieving the full speed of dot() with skip-NA support is impossible, but we can probably do much better. I might start by prototyping something in Numba, just to get a sense of what is achievable with a low-level approach. But keep in mind that functions like np.dot and np.einsum ("GEMM") are a few of the most highly optimized routines in numerical computing.
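A minimal sketch of the kind of low-level prototype mentioned here, assuming 2-D float inputs and NaNs only in `x`. The name `nandot` is hypothetical, and it only covers a single matrix-product signature rather than general einsum. It uses Numba's `njit` when Numba is installed and falls back to plain Python otherwise. It skips missing values without allocating a filled copy of `x`, but it should not be expected to approach BLAS/GEMM speed.

```python
import numpy as np

try:
    from numba import njit  # optional: compiles the loop if numba is installed
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(func):
        return func


@njit
def nandot(x, y):
    """Contract the last axis of 2-D ``x`` with the first axis of 2-D ``y``,
    treating NaN entries of ``x`` as zero without making a filled copy."""
    n, k = x.shape
    m = y.shape[1]
    out = np.zeros((n, m))
    for i in range(n):
        for p in range(k):
            xv = x[i, p]
            if np.isnan(xv):  # skip missing values instead of filling them
                continue
            for j in range(m):
                out[i, j] += xv * y[p, j]
    return out


x = np.array([[1.0, np.nan], [2.0, 3.0]])
y = np.array([[1.0, 2.0], [3.0, 4.0]])
result = nandot(x, y)
print(result)
```

NaNs in `y` would still propagate here; handling them too would need a second check in the inner loop.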

@max-sixty
Collaborator

Any idea why x.fillna(0.) is so much more expensive than x.dot(y)?

For some functions, where fillna(0.) gives a different result from skipna=True, I could see it being worthwhile to write new routines. But for the dot product, it's definitionally the same. And fillna seems a much simpler operation than dot...
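A quick NumPy check of that equivalence, assuming NaNs appear only in `x` (random example data). Summing the elementwise products with `nansum` matches zero-filling and then taking an ordinary dot product. This would not hold for a mean, where skipping NaNs also changes the denominator.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((4, 6))
x[rng.random(x.shape) < 0.3] = np.nan  # sprinkle missing values into x
y = rng.standard_normal(6)

# skipna=True semantics: sum the elementwise products, ignoring NaNs.
skipna_sum = np.nansum(x * y, axis=1)

# fillna(0) semantics: replace NaNs with 0, then take an ordinary dot product.
filled_dot = np.where(np.isnan(x), 0.0, x) @ y

print(np.allclose(skipna_sum, filled_dot))  # True
```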

@mathause
Collaborator

mathause commented Oct 9, 2020

fillna is basically where(notnull(data), data, other). I think what takes the longest is the where part, possibly because it makes a memory copy(?).
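A NumPy analogue of that description (the helper name `fillna_like` is hypothetical, not xarray's implementation). It shows that the `where` step always allocates a new output array, on top of the temporary boolean mask built by the not-null check.

```python
import numpy as np


def fillna_like(data, other):
    """Hypothetical NumPy analogue of ``where(notnull(data), data, other)``."""
    mask = ~np.isnan(data)  # temporary boolean array, same shape as data
    return np.where(mask, data, other)  # always a freshly allocated result


data = np.array([1.0, np.nan, 3.0])
filled = fillna_like(data, 0.0)

print(filled)  # [1. 0. 3.]
print(np.shares_memory(filled, data))  # False
```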

@max-sixty
Collaborator

Maybe on very small arrays it's quicker to do a product than a copy? As the array scales, it surely isn't: a dot product (e.g., n×n matrix multiplication) is O(n^3) or similar, while a copy is linear in the number of elements. I would be interested to see a repro...

@heerad
Author

heerad commented Oct 12, 2020

Adding on here: even if fillna were to create a memory copy, we'd only expect memory usage to double. However, in my case with dask-based chunking (via parallel=True in open_mfdataset), I'm seeing memory usage blow up to many times that (10x+) until all available memory is used up.

This is happening with x.fillna(0).dot(y) as well as x.notnull().dot(y) and x.weighted(y).sum(skipna=True), where x is the chunked array. This suggests that dask-based chunking isn't carrying through to the fillna and notnull ops, and the entire un-chunked arrays are being computed.

More evidence in favor: if I do (x*y).sum(skipna=True) I get the following error:

MemoryError: Unable to allocate [xxx] GiB for an array with shape [un-chunked array shape] and data type float64

I'm happy to live with a memory copy for now with fillna and notnull, but allocating the full, un-chunked array into memory is a showstopper. Is there a different workaround that I can use in the meantime?

@shoyer
Member

shoyer commented Oct 12, 2020

> I'm happy to live with a memory copy for now with fillna and notnull, but allocating the full, un-chunked array into memory is a showstopper. Is there a different workaround that I can use in the meantime?

This is surprising behavior, and definitely sounds like a bug!

If you could put together a minimal test case for reproducing the issue, we could look into it. It's hard to say what a work-around would be without knowing the source of the issue.

@heerad
Author

heerad commented Oct 12, 2020

See below. I temporarily write some files to netcdf then recombine them lazily using open_mfdataset.

The issue seems to present itself more consistently when my x is a constructed rolling window, and especially when it's a rolling window of a stacked dimension as in below.

I used the memory_profiler package and associated notebook extension (%%memit cell magic) to do memory profiling.

```python
import os

import numpy as np
import xarray as xr

N = 1000
N_per_file = 10
M = 100
K = 10
window_size = 150

tmp_dir = 'tmp'

os.makedirs(tmp_dir, exist_ok=True)

# save many netcdf files, later to be combined into a single dask-backed dataset
for i in range(0, N, N_per_file):

    # 3 dimensions:
    # d1 is the dim we're splitting our files/chunking along
    # d2 is a common dim among all files/chunks
    # d3 is a common dim among all files/chunks, where the first half is 0 and the second half is nan
    x_i = xr.DataArray(
        [[[0] * (K // 2) + [np.nan] * (K // 2)] * M] * N_per_file,
        coords=[('d1', list(range(i, i + N_per_file))),
                ('d2', list(range(M))),
                ('d3', list(range(K)))],
    )

    x_i.to_dataset(name='vals').to_netcdf(os.path.join(tmp_dir, 'file_{}.nc'.format(i)))

# open lazily; combine='by_coords' orders the files by their d1 coordinate
# (concat_dim is dropped because it has no effect with combine='by_coords')
x = xr.open_mfdataset(os.path.join(tmp_dir, '*.nc'), combine='by_coords', parallel=True).vals

# a rolling window along a stacked dimension
x_windows = x.stack(d13=['d1', 'd3']).rolling(d13=window_size).construct('window')

# we'll dot x_windows with y along the window dimension
y = xr.DataArray([1] * window_size, dims='window')

# incremental memory: 1.94 MiB
x_windows.dot(y).compute()

# incremental memory: 20.00 MiB
x_windows.notnull().dot(y).compute()

# incremental memory: 182.13 MiB
x_windows.fillna(0.).dot(y).compute()

# incremental memory: 211.52 MiB
x_windows.weighted(y).mean('window', skipna=True).compute()
```

@max-sixty
Collaborator

Right, that makes sense now. Given that .fillna(0) creates a copy, when we're using stride tricks (which is what construct does), that copy can be huge.
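A NumPy sketch of the effect, using `numpy.lib.stride_tricks.sliding_window_view` (NumPy >= 1.20) as a stand-in for `rolling(...).construct("window")`. It is an assumption that the two behave alike here; xarray's construct also pads the start with NaNs. The windowed array is only a view on the original data, but filling NaNs materializes every window element, so memory grows by roughly a factor of `window_size`.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

window_size = 150
base = np.full(10_000, np.nan)
base[::2] = 0.0  # half zeros, half NaNs

# A strided view: no data is copied, however many windows there are.
windows = sliding_window_view(base, window_size)

# Zero-filling (a stand-in for fillna(0.)) materializes every window element.
filled = np.where(np.isnan(windows), 0.0, windows)

print(windows.shape)  # (9851, 150)
print(np.shares_memory(windows, base))  # True
print(filled.nbytes // base.nbytes)  # 147
```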

So I think there's a spectrum of implementations of skipna, of which two are:

  • as a convenient alias of .fillna(0), like @mathause's example above IIUC
  • ensuring a copy isn't made, which may require diving into np.einsum (open to alternatives)

The second would be required for @heerad's case above.

@heerad
Author

heerad commented Oct 14, 2020

Adding on: whatever solution avoids blowing up memory, especially when used with construct, it would be useful to implement it for both fillna(0) and notnull(). One common use case is taking a weighted mean that normalizes by the sum of weights corresponding only to non-null entries, as here:

def _weighted_mean(
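A NumPy sketch of that normalization. `nan_weighted_mean` is a hypothetical helper, not xarray's `_weighted_mean`. The numerator is the dot product of zero-filled data with the weights (the fillna(0) step). The denominator is the dot product of the not-null mask with the weights (the notnull() step), so only weights of non-missing entries are counted.

```python
import numpy as np


def nan_weighted_mean(x, w):
    """Weighted mean over the last axis of ``x``, ignoring NaNs in ``x``."""
    valid = ~np.isnan(x)
    numerator = np.where(valid, x, 0.0) @ w  # fillna(0) then dot
    denominator = valid.astype(w.dtype) @ w  # notnull() then dot
    return numerator / denominator


x = np.array([[1.0, np.nan, 3.0], [np.nan, np.nan, 4.0]])
w = np.array([1.0, 2.0, 1.0])

result = nan_weighted_mean(x, w)
naive = np.where(np.isnan(x), 0.0, x) @ w / w.sum()  # normalizes by all weights

print(result)  # [2. 4.]
print(naive)  # [1. 1.]
```

A row where every value is NaN would give 0/0 here; a real implementation would need to decide what to return in that case.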

@heerad
Author

heerad commented Oct 20, 2020

On the topic of fillna(), I'm seeing an odd unrelated issue that I don't have an explanation for.

I have a dataarray x that I'm able to call x.compute() on.

When I do x.fillna(0).compute(), I get the following error:

KeyError: ('where-3a3[...long hex string]', 100, 0, 0, 4)

The stack trace shows it failing in a get_dependencies(dsk, key, task, as_list) call, made from a cull(dsk, keys) call in dask/optimization.py. get_dependencies itself is defined in dask/core.py.

I have no idea how to reproduce this simply... If it helps narrow things down, x is a dask array, one of the dimensions is a datetime64, and all others are strings. I've tried using both the default engine and netcdf4 when loading with open_mfdataset.


5 participants