Skip to content

Commit

Permalink
Use PyWavelet for CWT to be compatible with newer scipy versions (#1097)
Browse files Browse the repository at this point in the history
  • Loading branch information
igor-pechersky authored Feb 16, 2025
1 parent fdfeaae commit ebcfba6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ install_requires =
requests>=2.9.1
numpy >=1.15.1
pandas>=0.25.0
scipy>=1.14.0,<1.15;python_version>='3.10'
scipy>=1.2.0,<1.15;python_version<'3.10'
scipy>=1.14.0;python_version>='3.10'
scipy>=1.2.0;python_version<'3.10'
statsmodels>=0.13
patsy>=0.4.1
pywavelets
scikit-learn>=0.22.0
tqdm>=4.10.0
stumpy>=1.7.2
Expand Down
21 changes: 18 additions & 3 deletions tsfresh/feature_extraction/feature_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@

import numpy as np
import pandas as pd
import pywt
import stumpy
from numpy.linalg import LinAlgError
from scipy.signal import cwt, find_peaks_cwt, ricker, welch
from scipy.signal import find_peaks_cwt, welch
from scipy.stats import linregress
from statsmodels.tools.sm_exceptions import MissingDataError
from statsmodels.tsa.ar_model import AutoReg
Expand Down Expand Up @@ -1303,6 +1304,18 @@ def index_mass_quantile(x, param):
]


def _ricker(points, a):
"""Custom implementation of the ricker wavelet, copied from scipy as scipy dropped it."""
A = 2 / (np.sqrt(3 * a) * (np.pi**0.25))
wsq = a**2
vec = np.arange(0, points) - (points - 1.0) / 2
xsq = vec**2
mod = 1 - xsq / wsq
gauss = np.exp(-xsq / (2 * wsq))
total = A * mod * gauss
return total


@set_property("fctype", "simple")
def number_cwt_peaks(x, n):
"""
Expand All @@ -1320,7 +1333,9 @@ def number_cwt_peaks(x, n):
:return type: int
"""
return len(
find_peaks_cwt(vector=x, widths=np.array(list(range(1, n + 1))), wavelet=ricker)
find_peaks_cwt(
vector=x, widths=np.array(list(range(1, n + 1))), wavelet=_ricker
)
)


Expand Down Expand Up @@ -1384,7 +1399,7 @@ def cwt_coefficients(x, param):
coeff = parameter_combination["coeff"]

if widths not in calculated_cwt:
calculated_cwt[widths] = cwt(x, ricker, widths)
calculated_cwt[widths], _ = pywt.cwt(x, scales=widths, wavelet="mexh")

calculated_cwt_for_widths = calculated_cwt[widths]

Expand Down

0 comments on commit ebcfba6

Please sign in to comment.