
add axis support to cwt #509

Merged: 3 commits merged into master on Aug 5, 2019
Conversation

grlee77 (Contributor) commented Aug 1, 2019

Users requested batched operation for cwt in #445. This PR adds an axis argument: the input data may now be n-dimensional, with batched operation over all axes other than the specified cwt axis. For 1D input, the behavior is unchanged.

The final output shape for n-dimensional data is (len(scales),) + data.shape, i.e. the scales dimension always comes first, as it did previously in the 1D case.
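As a quick sketch of the new behavior (assuming a PyWavelets build that includes this PR, i.e. pywt >= 1.1):

```python
import numpy as np
import pywt

# a batch of 5 signals, each of length 128; transform along the last axis
data = np.random.randn(5, 128)
scales = np.arange(1, 17)

coef, freqs = pywt.cwt(data, scales, 'mexh', axis=-1)

# the scales dimension comes first, followed by the unchanged data shape
print(coef.shape)  # (16, 5, 128)
```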

For the 'conv' method, the implementation is a simple loop over the batch, but for the 'fft' method the FFT of the wavelet filter does not have to be repeated for each item in the batch, so there is a performance benefit to batched operation.
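The amortization behind the 'fft' method can be illustrated with plain NumPy (a random filter stands in for the integrated wavelet here; this is only a sketch of the idea, not the PR's actual code):

```python
import numpy as np

rng = np.random.default_rng(0)
n_batch, n = 5, 128
data = rng.standard_normal((n_batch, n))
filt = rng.standard_normal(n)   # stand-in for the wavelet filter

size = 2 * n - 1                          # full linear-convolution length
filt_f = np.fft.fft(filt, size)           # computed once per scale,
data_f = np.fft.fft(data, size, axis=-1)  # shared by the whole batch
out = np.fft.ifft(data_f * filt_f, axis=-1).real

# matches a per-signal direct convolution
ref = np.stack([np.convolve(row, filt) for row in data])
print(np.allclose(out, ref))  # True
```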

A subset of benchmark results. First, the non-batch cases:

      n     wavelet   max_scale       dtype         conv       fft
    ------ --------- ----------- --------------- ---------- ----------
     128     mexh        16       numpy.float32   2.01±0ms   5.13±0ms
     128     mexh        16       numpy.float64   4.32±0ms   5.99±0ms
     2048    shan        256      numpy.float32   397±0ms    68.0±0ms
     2048    shan        256      numpy.float64   864±0ms    100±0ms

And a few batch (n_batch=5) cases:

      n     wavelet   max_scale       dtype         conv       fft
    ------ --------- ----------- --------------- ---------- ----------
     128     mexh        16       numpy.float32   6.68±0ms   6.67±0ms
     128     mexh        16       numpy.float64   5.46±0ms   7.39±0ms
     2048    shan        256      numpy.float32   1.87±0s    178±0ms
     2048    shan        256      numpy.float64   4.32±0s    291±0ms

Summary for the 2048/shan/float32 case:

with 'conv': the n_batch=5 batch takes 4.7 times longer than the non-batch case

with 'fft': the n_batch=5 batch takes 2.6 times longer than the non-batch case

for the non-batch case, method 'fft' is 5.8 times faster than 'conv'

for the n_batch=5 batch case, method 'fft' is 10.5 times faster than 'conv'
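These ratios follow directly from the float32 2048/shan rows of the table:

```python
# timings (seconds) from the 2048/shan/float32 benchmark rows above
conv_single, conv_batch = 0.397, 1.87
fft_single, fft_batch = 0.068, 0.178

print(round(conv_batch / conv_single, 1))  # 4.7: batching slowdown with conv
print(round(fft_batch / fft_single, 1))    # 2.6: batching slowdown with fft
print(round(conv_single / fft_single, 1))  # 5.8: fft vs conv, non-batch
print(round(conv_batch / fft_batch, 1))    # 10.5: fft vs conv, batch
```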

closes #445

@grlee77 grlee77 added this to the v1.1 milestone Aug 1, 2019
@grlee77 grlee77 mentioned this pull request Aug 1, 2019
grlee77 (Contributor, Author) commented Aug 1, 2019

I should note that the FFT results in the benchmarks above use single-threaded scipy.fftpack as the FFT backend. A multi-threaded FFT backend would likely give an additional benefit for batched operation.
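For example, the newer scipy.fft interface (the successor to scipy.fftpack, and not the backend used for the numbers above) can split the independent per-row transforms across threads via its workers argument; a hypothetical sketch:

```python
import numpy as np
from scipy import fft

x = np.random.randn(64, 4096)

# workers > 1 parallelizes the independent per-row FFTs across threads
X = fft.fft(x, axis=-1, workers=4)
x_back = fft.ifft(X, axis=-1, workers=4).real

print(np.allclose(x, x_back))  # True
```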

rgommers (Member) commented Aug 5, 2019

Backwards-compatible, and seems like a nice usability improvement.

I mostly just reviewed the tests and description. axis=-1 is the right default. Going to merge this, looks good. Thanks @grlee77

@rgommers rgommers merged commit 9ba3a1c into PyWavelets:master Aug 5, 2019
@grlee77 grlee77 deleted the batch_cwt branch November 13, 2019 00:46
Successfully merging this pull request may close these issues:

Batch load for pywt.cwt