
weighted operations: performance optimisations #3883

Open · mathause opened this issue Mar 24, 2020 · 3 comments

@mathause (Collaborator)

There was a discussion on the performance of the weighted mean/sum, in terms of both memory footprint and speed, and there may indeed be some things that can be optimized. See the posts at the end of the PR. However, the optimal implementation will probably depend on the use case, and some profiling will be required.

I'll just open an issue to keep track of this.
@seth-p
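
For context, the memory/speed tradeoff is mainly between a naive implementation that materialises the full broadcast product and an einsum-based one that does not. A minimal sketch of the two variants (my illustration, not xarray's internal code; it assumes no NaNs, xarray's default dim_0/dim_1 names, and the dims keyword of xr.dot as it was named at the time of this issue):

```python
import numpy as np
import xarray as xr

data = xr.DataArray(np.random.randn(1000, 1000, 10))  # dims: dim_0, dim_1, dim_2
weights = xr.DataArray(np.random.rand(1000, 1000))    # dims: dim_0, dim_1

# naive: materialises the full (1000, 1000, 10) product before reducing
mean_naive = (data * weights).sum(dim=("dim_0", "dim_1")) / weights.sum()

# einsum-based: contracts over dim_0/dim_1 without keeping the
# full product in memory
mean_dot = xr.dot(data, weights, dims=("dim_0", "dim_1")) / weights.sum()

np.testing.assert_allclose(mean_naive.values, mean_dot.values)
```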

@mathause (Collaborator, Author)

maybe relevant: #1995

@mathause (Collaborator, Author) commented May 18, 2020

```python
# run in IPython with the line_profiler extension installed
%load_ext line_profiler

import numpy as np
import xarray as xr

from xarray.core.weighted import Weighted as w

shape_weights = (1000, 1000)
shape_data = (1000, 1000, 10)
add_nans = False


def lprun_weighted(shape_weights, shape_data, add_nans, skipna=None):

    weights = xr.DataArray(np.random.randn(*shape_weights))

    data = np.random.randn(*shape_data)

    # set approximately 25 % of the data to NaN
    if add_nans:
        c = int(data.size * 0.25)
        data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan

    data = xr.DataArray(data)

    return data.weighted(weights).mean(skipna=skipna)


# line-profile the individual Weighted methods (1 ms time unit)
%lprun -f w._reduce -f w._weighted_mean -f w._sum_of_weights -f w._weighted_sum -f w.__init__ -u 1e-03 lprun_weighted(shape_weights, shape_data, add_nans, skipna=None)
```
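
For reference when reading the profile, the skipna=True path boils down to roughly the following (a simplified sketch of the logic with a hypothetical function name, not the exact library code):

```python
import xarray as xr


def weighted_mean_skipna(da, weights, dim):
    # NaNs must contribute neither to the weighted sum nor to the
    # sum of weights, so mask them out
    mask = da.notnull()                        # first notnull() pass
    sum_of_weights = xr.dot(mask, weights, dims=dim)
    # fillna(0.0) re-derives the null mask internally, hence the
    # duplicated work discussed in the next comment
    weighted_sum = xr.dot(da.fillna(0.0), weights, dims=dim)
    return weighted_sum / sum_of_weights
```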

@mathause (Collaborator, Author) commented May 26, 2020

weighted(weights).mean(skipna=True) calls

mask = da.notnull()

and

da = da.fillna(0.0)

da.fillna(0.0) in turn calls where(null(data), other), so null/notnull is computed twice. This could be optimized by doing self.null = obj.isnull() in __init__(...). That would also allow setting self.any_null = self.null.any() and skipping the NaN handling entirely when there are no NaNs. However, this needs more thought if a Dataset is passed. Probably overkill, but leaving this here for reference.
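
A minimal sketch of the proposed caching (hypothetical class and attribute names, DataArray case only; a Dataset would need a per-variable mask):

```python
import xarray as xr


class CachedWeighted:
    """Hypothetical Weighted variant that computes the null mask once."""

    def __init__(self, obj: xr.DataArray, weights: xr.DataArray):
        self.obj = obj
        self.weights = weights
        # compute the mask once, in __init__, and reuse it in mean()
        self.null = obj.isnull()
        # if there are no NaNs at all, NaN handling can be skipped entirely
        self.any_null = bool(self.null.any())

    def mean(self, dim):
        if self.any_null:
            notnull = ~self.null
            # where(...) with the cached mask avoids the second
            # notnull() call hidden inside fillna(0.0)
            data = self.obj.where(notnull, 0.0)
            sum_of_weights = xr.dot(notnull, self.weights, dims=dim)
        else:
            data = self.obj
            sum_of_weights = xr.dot(xr.ones_like(self.obj), self.weights, dims=dim)
        return xr.dot(data, self.weights, dims=dim) / sum_of_weights
```

Usage would mirror the existing API, e.g. CachedWeighted(data, weights).mean(dim=("dim_0", "dim_1")).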
