diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 414323f1..ec57097b 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -2,13 +2,16 @@ from __future__ import annotations +import functools import typing -from collections.abc import Sequence -from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload +import warnings +from collections.abc import Callable, Sequence +from typing import Any, NamedTuple, Optional, Protocol, Union, cast, overload import numpy as np import pywt import torch +from typing_extensions import ParamSpec, TypeVar from .constants import ( OrthogonalizeMethod, @@ -253,3 +256,55 @@ def _map_result( Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst ) return approx, *cast_result_lst + + +Param = ParamSpec("Param") +RetType = TypeVar("RetType") + + +def _deprecated_alias( + **aliases: str, +) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]: + """Handle deprecated function and method arguments. + + Use as follows:: + + @deprecated_alias(old_arg='new_arg') + def myfunc(new_arg): + ... + + Adapted from https://stackoverflow.com/a/49802489 + """ + + def rename_kwargs( + func_name: str, + kwargs: Param.kwargs, + aliases: dict[str, str], + ) -> None: + """Rename deprecated kwarg.""" + for alias, new in aliases.items(): + if alias in kwargs: + if new in kwargs: + raise TypeError( + f"{func_name} received both {alias} and {new} as arguments!" + f" {alias} is deprecated, use {new} instead." + ) + warnings.warn( + message=( + f"`{alias}` is deprecated as an argument to `{func_name}`; use" + f" `{new}` instead." + ), + category=DeprecationWarning, + stacklevel=3, + ) + kwargs[new] = kwargs.pop(alias) + + def deco(f: Callable[Param, RetType]) -> Callable[Param, RetType]: + @functools.wraps(f) + def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType: + rename_kwargs(f.__name__, kwargs, aliases) + return f(*args, **kwargs) + + return wrapper + + return deco diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 0aaec943..ed681a4d 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -17,6 +17,7 @@ from ._util import ( Wavelet, _as_wavelet, + _deprecated_alias, _is_boundary_mode_supported, _is_dtype_supported, _unfold_axes, @@ -182,12 +183,13 @@ class MatrixWavedec(BaseMatrixWaveDec): >>> coefficients = matrix_wavedec(data_torch) """ + @_deprecated_alias(boundary="boundary_orthogonalization") def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, axis: Optional[int] = -1, - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", odd_coeff_padding_mode: BoundaryMode = "zero", ) -> None: """Create a sparse matrix fast wavelet transform object. @@ -202,8 +204,9 @@ def __init__( None. axis (int, optional): The axis we would like to transform. Defaults to -1. - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. odd_coeff_padding_mode: The constructed FWT matrices require inputs with even lengths. Thus, any odd-length approximation coefficients are padded to an even length using this mode, @@ -217,8 +220,8 @@ def __init__( """ self.wavelet = _as_wavelet(wavelet) self.level = level - self.boundary = boundary self.odd_coeff_padding_mode = odd_coeff_padding_mode + self.boundary_orthogonalization = boundary_orthogonalization if isinstance(axis, int): self.axis = axis @@ -231,7 +234,7 @@ def __init__( self.padded = False self.size_list: list[int] = [] - if not _is_boundary_mode_supported(self.boundary): + if not _is_boundary_mode_supported(self.boundary_orthogonalization): raise NotImplementedError if self.wavelet.dec_len != self.wavelet.rec_len: @@ -311,7 +314,7 @@ def _construct_analysis_matrices( an = construct_boundary_a( self.wavelet, curr_length, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) @@ -402,11 +405,12 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: return result_list +@_deprecated_alias(boundary="boundary_orthogonalization") def construct_boundary_a( wavelet: Union[Wavelet, str], length: int, device: Union[torch.device, str] = "cpu", - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", dtype: torch.dtype = torch.float64, ) -> torch.Tensor: """Construct a boundary-wavelet filter 1d-analysis matrix. @@ -415,8 +419,9 @@ def construct_boundary_a( wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. length (int): The number of entries in the input signal. - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. device: Where to place the matrix. Choose cpu or cuda. Defaults to cpu. dtype: Choose float32 or float64. @@ -426,15 +431,16 @@ def construct_boundary_a( """ wavelet = _as_wavelet(wavelet) a_full = _construct_a(wavelet, length, dtype=dtype, device=device) - a_orth = orthogonalize(a_full, wavelet.dec_len, method=boundary) + a_orth = orthogonalize(a_full, wavelet.dec_len, method=boundary_orthogonalization) return a_orth +@_deprecated_alias(boundary="boundary_orthogonalization") def construct_boundary_s( wavelet: Union[Wavelet, str], length: int, device: Union[torch.device, str] = "cpu", - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", dtype: torch.dtype = torch.float64, ) -> torch.Tensor: """Construct a boundary-wavelet filter 1d-synthesis matarix. @@ -445,8 +451,9 @@ def construct_boundary_s( length (int): The number of entries in the input signal. device (torch.device): Where to place the matrix. Choose cpu or cuda. Defaults to cpu. - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. dtype: Choose torch.float32 or torch.float64. Defaults to torch.float64. @@ -455,7 +462,9 @@ def construct_boundary_s( """ wavelet = _as_wavelet(wavelet) s_full = _construct_s(wavelet, length, dtype=dtype, device=device) - s_orth = orthogonalize(s_full.transpose(1, 0), wavelet.rec_len, method=boundary) + s_orth = orthogonalize( + s_full.transpose(1, 0), wavelet.rec_len, method=boundary_orthogonalization + ) return s_orth.transpose(1, 0) @@ -476,11 +485,12 @@ class MatrixWaverec(object): >>> reconstruction = matrix_waverec(coefficients) """ + @_deprecated_alias(boundary="boundary_orthogonalization") def __init__( self, wavelet: Union[Wavelet, str], axis: int = -1, - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", ) -> None: """Create the inverse matrix-based fast wavelet transformation. @@ -491,8 +501,9 @@ def __init__( for possible choices. axis (int): The axis transformed by the original decomposition defaults to -1 or the last axis. - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. Raises: NotImplementedError: If the selected `boundary` mode is not supported. @@ -500,7 +511,7 @@ def __init__( axis is not an integer. """ self.wavelet = _as_wavelet(wavelet) - self.boundary = boundary + self.boundary_orthogonalization = boundary_orthogonalization if isinstance(axis, int): self.axis = axis else: @@ -511,7 +522,7 @@ def __init__( self.input_length: Optional[int] = None self.padded = False - if not _is_boundary_mode_supported(self.boundary): + if not _is_boundary_mode_supported(self.boundary_orthogonalization): raise NotImplementedError if self.wavelet.dec_len != self.wavelet.rec_len: @@ -591,7 +602,7 @@ def _construct_synthesis_matrices( sn = construct_boundary_s( self.wavelet, curr_length, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 8ab525e8..016f4ed7 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -17,6 +17,7 @@ _as_wavelet, _check_axes_argument, _check_if_tensor, + _deprecated_alias, _is_boundary_mode_supported, _is_dtype_supported, _map_result, @@ -148,12 +149,13 @@ def _construct_s_2( return transpose_synthesis +@_deprecated_alias(boundary="boundary_orthogonalization") def construct_boundary_a2( wavelet: Union[Wavelet, str], height: int, width: int, device: Union[torch.device, str], - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", dtype: torch.dtype = torch.float64, ) -> torch.Tensor: """Construct a boundary fwt matrix for the input wavelet. @@ -167,8 +169,9 @@ def construct_boundary_a2( Should be divisible by two. device (torch.device): Where to place the matrix. Either on the CPU or GPU. - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. dtype (torch.dtype, optional): The desired data type for the matrix. Defaults to torch.float64. @@ -177,17 +180,20 @@ def construct_boundary_a2( """ wavelet = _as_wavelet(wavelet) a = _construct_a_2(wavelet, height, width, device, dtype=dtype, mode="sameshift") - orth_a = orthogonalize(a, wavelet.dec_len**2, method=boundary) # noqa: BLK100 + orth_a = orthogonalize( + a, wavelet.dec_len**2, method=boundary_orthogonalization + ) # noqa: BLK100 return orth_a +@_deprecated_alias(boundary="boundary_orthogonalization") def construct_boundary_s2( wavelet: Union[Wavelet, str], height: int, width: int, device: Union[torch.device, str], *, - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", dtype: torch.dtype = torch.float64, ) -> torch.Tensor: """Construct a 2d-fwt matrix, with boundary wavelets. @@ -198,8 +204,9 @@ def construct_boundary_s2( height (int): The original height of the input matrix. width (int): The width of the original input matrix. device (torch.device): Choose CPU or GPU. - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. dtype (torch.dtype, optional): The data type of the sparse matrix, choose float32 or 64. Defaults to torch.float64. @@ -210,7 +217,9 @@ def construct_boundary_s2( wavelet = _as_wavelet(wavelet) s = _construct_s_2(wavelet, height, width, device, dtype=dtype) orth_s = orthogonalize( - s.transpose(1, 0), wavelet.rec_len**2, method=boundary # noqa: BLK100 + s.transpose(1, 0), + wavelet.rec_len**2, + method=boundary_orthogonalization, # noqa: BLK100 ).transpose(1, 0) return orth_s @@ -256,12 +265,13 @@ class MatrixWavedec2(BaseMatrixWaveDec): >>> mat_coeff = matrixfwt(pt_face) """ + @_deprecated_alias(boundary="boundary_orthogonalization") def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, axes: tuple[int, int] = (-2, -1), - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", separable: bool = True, odd_coeff_padding_mode: BoundaryMode = "zero", ): @@ -277,8 +287,9 @@ def __init__( None. axes (int, int): A tuple with the axes to transform. Defaults to (-2, -1). - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. separable (bool): If this flag is set, a separable transformation is used, i.e. a 1d transformation along each axis. Matrix construction is significantly faster for separable @@ -301,7 +312,7 @@ def __init__( _check_axes_argument(list(axes)) self.axes = tuple(axes) self.level = level - self.boundary = boundary + self.boundary_orthogonalization = boundary_orthogonalization self.odd_coeff_padding_mode = odd_coeff_padding_mode self.separable = separable self.input_signal_shape: Optional[tuple[int, int]] = None @@ -311,7 +322,7 @@ def __init__( self.pad_list: list[tuple[bool, bool]] = [] self.padded = False - if not _is_boundary_mode_supported(self.boundary): + if not _is_boundary_mode_supported(self.boundary_orthogonalization): raise NotImplementedError if self.wavelet.dec_len != self.wavelet.rec_len: @@ -394,14 +405,14 @@ def _construct_analysis_matrices( analysis_matrix_rows = construct_boundary_a( wavelet=self.wavelet, length=current_height, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) analysis_matrix_cols = construct_boundary_a( wavelet=self.wavelet, length=current_width, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) @@ -413,7 +424,7 @@ def _construct_analysis_matrices( wavelet=self.wavelet, height=current_height, width=current_width, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) @@ -568,11 +579,12 @@ class MatrixWaverec2(object): >>> reconstruction = matrixifwt(mat_coeff) """ + @_deprecated_alias(boundary="boundary_orthogonalization") def __init__( self, wavelet: Union[Wavelet, str], axes: tuple[int, int] = (-2, -1), - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", separable: bool = True, ): """Create the inverse matrix-based fast wavelet transformation. @@ -584,8 +596,9 @@ def __init__( for possible choices. axes (int, int): The axes transformed by waverec2. Defaults to (-2, -1). - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. separable (bool): If this flag is set, a separable transformation is used, i.e. a 1d transformation along each axis. This is significantly faster than a non-separable transformation since only a small constant- @@ -598,7 +611,7 @@ def __init__( ValueError: If the wavelet filters have different lengths. """ self.wavelet = _as_wavelet(wavelet) - self.boundary = boundary + self.boundary_orthogonalization = boundary_orthogonalization self.separable = separable if len(axes) != 2: @@ -615,7 +628,7 @@ def __init__( self.padded = False - if not _is_boundary_mode_supported(self.boundary): + if not _is_boundary_mode_supported(self.boundary_orthogonalization): raise NotImplementedError if self.wavelet.dec_len != self.wavelet.rec_len: @@ -692,14 +705,14 @@ def _construct_synthesis_matrices( synthesis_matrix_rows = construct_boundary_s( wavelet=self.wavelet, length=current_height, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) synthesis_matrix_cols = construct_boundary_s( wavelet=self.wavelet, length=current_width, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) @@ -711,7 +724,7 @@ def _construct_synthesis_matrices( self.wavelet, current_height, current_width, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 247d73b7..df17fe3c 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -14,6 +14,7 @@ _as_wavelet, _check_axes_argument, _check_if_tensor, + _deprecated_alias, _fold_axes, _is_boundary_mode_supported, _is_dtype_supported, @@ -55,12 +56,13 @@ def _matrix_pad_3( class MatrixWavedec3(object): """Compute 3d separable transforms.""" + @_deprecated_alias(boundary="boundary_orthogonalization") def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", odd_coeff_padding_mode: BoundaryMode = "zero", ): """Create a *separable* three-dimensional fast boundary wavelet transform. @@ -75,8 +77,9 @@ def __init__( for possible choices. level (int, optional): The desired decomposition level. Defaults to None. - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. odd_coeff_padding_mode: The constructed FWT matrices require inputs with even lengths. Thus, any odd-length approximation coefficients are padded to an even length using this mode, @@ -91,7 +94,7 @@ def __init__( """ self.wavelet = _as_wavelet(wavelet) self.level = level - self.boundary = boundary + self.boundary_orthogonalization = boundary_orthogonalization self.odd_coeff_padding_mode = odd_coeff_padding_mode if len(axes) != 3: raise ValueError("3D transforms work with three axes.") @@ -101,7 +104,7 @@ def __init__( self.input_signal_shape: Optional[tuple[int, int, int]] = None self.fwt_matrix_list: list[list[torch.Tensor]] = [] - if not _is_boundary_mode_supported(self.boundary): + if not _is_boundary_mode_supported(self.boundary_orthogonalization): raise NotImplementedError if self.wavelet.dec_len != self.wavelet.rec_len: raise ValueError("All filters must have the same length") @@ -149,7 +152,7 @@ def _construct_analysis_matrices( matrix_construction_fun = partial( construct_boundary_a, wavelet=self.wavelet, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) @@ -286,11 +289,12 @@ def _split_rec( class MatrixWaverec3(object): """Reconstruct a signal from 3d-separable-fwt coefficients.""" + @_deprecated_alias(boundary="boundary_orthogonalization") def __init__( self, wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), - boundary: OrthogonalizeMethod = "qr", + boundary_orthogonalization: OrthogonalizeMethod = "qr", ): """Compute a three-dimensional separable boundary wavelet synthesis transform. @@ -301,8 +305,9 @@ def __init__( for possible choices. axes (tuple[int, int, int]): Transform these axes instead of the last three. Defaults to (-3, -2, -1). - boundary : The method used for boundary filter treatment, - see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. + boundary_orthogonalization: The method used to orthogonalize + boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`. + Defaults to 'qr'. Raises: NotImplementedError: If the selected `boundary` mode is not supported. @@ -314,12 +319,12 @@ def __init__( else: _check_axes_argument(list(axes)) self.axes = axes - self.boundary = boundary + self.boundary_orthogonalization = boundary_orthogonalization self.ifwt_matrix_list: list[list[torch.Tensor]] = [] self.input_signal_shape: Optional[tuple[int, int, int]] = None self.level: Optional[int] = None - if not _is_boundary_mode_supported(self.boundary): + if not _is_boundary_mode_supported(self.boundary_orthogonalization): raise NotImplementedError if self.wavelet.dec_len != self.wavelet.rec_len: @@ -364,7 +369,7 @@ def _construct_synthesis_matrices( matrix_construction_fun = partial( construct_boundary_s, wavelet=self.wavelet, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, device=device, dtype=dtype, ) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 233f00d2..e92555d7 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -84,7 +84,7 @@ def __init__( The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. axis (int): The axis to transform. Defaults to -1. - boundary_orthogonalization : The orthogonalization method + boundary_orthogonalization: The orthogonalization method to use in the sparse matrix backend, see :data:`ptwt.constants.OrthogonalizeMethod`. Only used if `mode` equals 'boundary'. Defaults to 'qr'. @@ -108,7 +108,7 @@ def __init__( """ self.wavelet = _as_wavelet(wavelet) self.mode = mode - self.boundary = boundary_orthogonalization + self.boundary_orthogonalization = boundary_orthogonalization self._matrix_wavedec_dict: dict[int, MatrixWavedec] = {} self._matrix_waverec_dict: dict[int, MatrixWaverec] = {} self.maxlevel: Optional[int] = None @@ -183,7 +183,10 @@ def _get_wavedec( if self.mode == "boundary": if length not in self._matrix_wavedec_dict.keys(): self._matrix_wavedec_dict[length] = MatrixWavedec( - self.wavelet, level=1, boundary=self.boundary, axis=self.axis + self.wavelet, + level=1, + boundary_orthogonalization=self.boundary_orthogonalization, + axis=self.axis, ) return self._matrix_wavedec_dict[length] else: @@ -198,7 +201,9 @@ def _get_waverec( if self.mode == "boundary": if length not in self._matrix_waverec_dict.keys(): self._matrix_waverec_dict[length] = MatrixWaverec( - self.wavelet, boundary=self.boundary, axis=self.axis + self.wavelet, + boundary_orthogonalization=self.boundary_orthogonalization, + axis=self.axis, ) return self._matrix_waverec_dict[length] else: @@ -328,7 +333,7 @@ def __init__( """ self.wavelet = _as_wavelet(wavelet) self.mode = mode - self.boundary = boundary_orthogonalization + self.boundary_orthogonalization = boundary_orthogonalization self.separable = separable self.matrix_wavedec2_dict: dict[tuple[int, ...], MatrixWavedec2] = {} self.matrix_waverec2_dict: dict[tuple[int, ...], MatrixWaverec2] = {} @@ -413,7 +418,7 @@ def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ self.wavelet, level=1, axes=self.axes, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, separable=self.separable, ) fun = self.matrix_wavedec2_dict[shape] @@ -442,7 +447,7 @@ def _get_waverec( self.matrix_waverec2_dict[shape] = MatrixWaverec2( self.wavelet, axes=self.axes, - boundary=self.boundary, + boundary_orthogonalization=self.boundary_orthogonalization, separable=self.separable, ) return self.matrix_waverec2_dict[shape] diff --git a/tests/test_matrix_fwt.py b/tests/test_matrix_fwt.py index d95f2081..9d63a718 100644 --- a/tests/test_matrix_fwt.py +++ b/tests/test_matrix_fwt.py @@ -140,9 +140,9 @@ def test_boundary_transform_1d( """Ensure matrix fwt reconstructions are pywt compatible.""" data_torch = torch.from_numpy(data.astype(np.float64)) wavelet = pywt.Wavelet(wavelet_str) - matrix_wavedec = MatrixWavedec(wavelet, level=level, boundary=boundary) + matrix_wavedec = MatrixWavedec(wavelet, level=level, boundary_orthogonalization=boundary) coeffs = matrix_wavedec(data_torch) - matrix_waverec = MatrixWaverec(wavelet, boundary=boundary) + matrix_waverec = MatrixWaverec(wavelet, boundary_orthogonalization=boundary) rec = matrix_waverec(coeffs) rec_pywt = pywt.waverec( pywt.wavedec(data_torch.numpy(), wavelet, mode="zero"), wavelet @@ -172,9 +172,9 @@ def test_matrix_transform_1d_rebuild( """Ensure matrix fwt reconstructions are pywt compatible.""" data_list = [np.random.randn(18), np.random.randn(21)] wavelet = pywt.Wavelet(wavelet_str) - matrix_waverec = MatrixWaverec(wavelet, boundary=boundary) + matrix_waverec = MatrixWaverec(wavelet, boundary_orthogonalization=boundary) for level in [2, 1]: - matrix_wavedec = MatrixWavedec(wavelet, level=level, boundary=boundary) + matrix_wavedec = MatrixWavedec(wavelet, level=level, boundary_orthogonalization=boundary) for data in data_list: data_torch = torch.from_numpy(data.astype(np.float64)) coeffs = matrix_wavedec(data_torch)