Skip to content

Commit

Permalink
Rename boundary arg to boundary_orthogonalization
Browse files Browse the repository at this point in the history
  • Loading branch information
felixblanke committed Jun 26, 2024
1 parent 5862bb5 commit 91726be
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 69 deletions.
59 changes: 57 additions & 2 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from __future__ import annotations

import functools
import typing
from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload
import warnings
from collections.abc import Callable, Sequence
from typing import Any, NamedTuple, Optional, Protocol, Union, cast, overload

import numpy as np
import pywt
import torch
from typing_extensions import ParamSpec, TypeVar

from .constants import (
OrthogonalizeMethod,
Expand Down Expand Up @@ -253,3 +256,55 @@ def _map_result(
Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
)
return approx, *cast_result_lst


Param = ParamSpec("Param")
RetType = TypeVar("RetType")


def _deprecated_alias(
**aliases: str,
) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]:
"""Handle deprecated function and method arguments.
Use as follows::
@deprecated_alias(old_arg='new_arg')
def myfunc(new_arg):
...
Adapted from https://stackoverflow.com/a/49802489
"""

def rename_kwargs(
func_name: str,
kwargs: Param.kwargs,
aliases: dict[str, str],
) -> None:
"""Rename deprecated kwarg."""
for alias, new in aliases.items():
if alias in kwargs:
if new in kwargs:
raise TypeError(
f"{func_name} received both {alias} and {new} as arguments!"
f" {alias} is deprecated, use {new} instead."
)
warnings.warn(
message=(
f"`{alias}` is deprecated as an argument to `{func_name}`; use"
f" `{new}` instead."
),
category=DeprecationWarning,
stacklevel=3,
)
kwargs[new] = kwargs.pop(alias)

def deco(f: Callable[Param, RetType]) -> Callable[Param, RetType]:
@functools.wraps(f)
def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType:
rename_kwargs(f.__name__, kwargs, aliases)
return f(*args, **kwargs)

return wrapper

return deco
51 changes: 31 additions & 20 deletions src/ptwt/matmul_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ._util import (
Wavelet,
_as_wavelet,
_deprecated_alias,
_is_boundary_mode_supported,
_is_dtype_supported,
_unfold_axes,
Expand Down Expand Up @@ -182,12 +183,13 @@ class MatrixWavedec(BaseMatrixWaveDec):
>>> coefficients = matrix_wavedec(data_torch)
"""

@_deprecated_alias(boundary="boundary_orthogonalization")
def __init__(
self,
wavelet: Union[Wavelet, str],
level: Optional[int] = None,
axis: Optional[int] = -1,
boundary: OrthogonalizeMethod = "qr",
boundary_orthogonalization: OrthogonalizeMethod = "qr",
odd_coeff_padding_mode: BoundaryMode = "zero",
) -> None:
"""Create a sparse matrix fast wavelet transform object.
Expand All @@ -202,8 +204,9 @@ def __init__(
None.
axis (int, optional): The axis we would like to transform.
Defaults to -1.
boundary : The method used for boundary filter treatment,
see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'.
boundary_orthogonalization: The method used to orthogonalize
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
Defaults to 'qr'.
odd_coeff_padding_mode: The constructed FWT matrices require inputs
with even lengths. Thus, any odd-length approximation coefficients
are padded to an even length using this mode,
Expand All @@ -217,8 +220,8 @@ def __init__(
"""
self.wavelet = _as_wavelet(wavelet)
self.level = level
self.boundary = boundary
self.odd_coeff_padding_mode = odd_coeff_padding_mode
self.boundary_orthogonalization = boundary_orthogonalization

if isinstance(axis, int):
self.axis = axis
Expand All @@ -231,7 +234,7 @@ def __init__(
self.padded = False
self.size_list: list[int] = []

if not _is_boundary_mode_supported(self.boundary):
if not _is_boundary_mode_supported(self.boundary_orthogonalization):
raise NotImplementedError

if self.wavelet.dec_len != self.wavelet.rec_len:
Expand Down Expand Up @@ -311,7 +314,7 @@ def _construct_analysis_matrices(
an = construct_boundary_a(
self.wavelet,
curr_length,
boundary=self.boundary,
boundary_orthogonalization=self.boundary_orthogonalization,
device=device,
dtype=dtype,
)
Expand Down Expand Up @@ -402,11 +405,12 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]:
return result_list


@_deprecated_alias(boundary="boundary_orthogonalization")
def construct_boundary_a(
wavelet: Union[Wavelet, str],
length: int,
device: Union[torch.device, str] = "cpu",
boundary: OrthogonalizeMethod = "qr",
boundary_orthogonalization: OrthogonalizeMethod = "qr",
dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
"""Construct a boundary-wavelet filter 1d-analysis matrix.
Expand All @@ -415,8 +419,9 @@ def construct_boundary_a(
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
length (int): The number of entries in the input signal.
boundary : The method used for boundary filter treatment,
see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'.
boundary_orthogonalization: The method used to orthogonalize
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
Defaults to 'qr'.
device: Where to place the matrix. Choose cpu or cuda.
Defaults to cpu.
dtype: Choose float32 or float64.
Expand All @@ -426,15 +431,16 @@ def construct_boundary_a(
"""
wavelet = _as_wavelet(wavelet)
a_full = _construct_a(wavelet, length, dtype=dtype, device=device)
a_orth = orthogonalize(a_full, wavelet.dec_len, method=boundary)
a_orth = orthogonalize(a_full, wavelet.dec_len, method=boundary_orthogonalization)
return a_orth


@_deprecated_alias(boundary="boundary_orthogonalization")
def construct_boundary_s(
wavelet: Union[Wavelet, str],
length: int,
device: Union[torch.device, str] = "cpu",
boundary: OrthogonalizeMethod = "qr",
boundary_orthogonalization: OrthogonalizeMethod = "qr",
dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
"""Construct a boundary-wavelet filter 1d-synthesis matarix.
Expand All @@ -445,8 +451,9 @@ def construct_boundary_s(
length (int): The number of entries in the input signal.
device (torch.device): Where to place the matrix.
Choose cpu or cuda. Defaults to cpu.
boundary : The method used for boundary filter treatment,
see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'.
boundary_orthogonalization: The method used to orthogonalize
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
Defaults to 'qr'.
dtype: Choose torch.float32 or torch.float64.
Defaults to torch.float64.
Expand All @@ -455,7 +462,9 @@ def construct_boundary_s(
"""
wavelet = _as_wavelet(wavelet)
s_full = _construct_s(wavelet, length, dtype=dtype, device=device)
s_orth = orthogonalize(s_full.transpose(1, 0), wavelet.rec_len, method=boundary)
s_orth = orthogonalize(
s_full.transpose(1, 0), wavelet.rec_len, method=boundary_orthogonalization
)
return s_orth.transpose(1, 0)


Expand All @@ -476,11 +485,12 @@ class MatrixWaverec(object):
>>> reconstruction = matrix_waverec(coefficients)
"""

@_deprecated_alias(boundary="boundary_orthogonalization")
def __init__(
self,
wavelet: Union[Wavelet, str],
axis: int = -1,
boundary: OrthogonalizeMethod = "qr",
boundary_orthogonalization: OrthogonalizeMethod = "qr",
) -> None:
"""Create the inverse matrix-based fast wavelet transformation.
Expand All @@ -491,16 +501,17 @@ def __init__(
for possible choices.
axis (int): The axis transformed by the original decomposition
defaults to -1 or the last axis.
boundary : The method used for boundary filter treatment,
see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'.
boundary_orthogonalization: The method used to orthogonalize
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
Defaults to 'qr'.
Raises:
NotImplementedError: If the selected `boundary` mode is not supported.
ValueError: If the wavelet filters have different lengths or if
axis is not an integer.
"""
self.wavelet = _as_wavelet(wavelet)
self.boundary = boundary
self.boundary_orthogonalization = boundary_orthogonalization
if isinstance(axis, int):
self.axis = axis
else:
Expand All @@ -511,7 +522,7 @@ def __init__(
self.input_length: Optional[int] = None
self.padded = False

if not _is_boundary_mode_supported(self.boundary):
if not _is_boundary_mode_supported(self.boundary_orthogonalization):
raise NotImplementedError

if self.wavelet.dec_len != self.wavelet.rec_len:
Expand Down Expand Up @@ -591,7 +602,7 @@ def _construct_synthesis_matrices(
sn = construct_boundary_s(
self.wavelet,
curr_length,
boundary=self.boundary,
boundary_orthogonalization=self.boundary_orthogonalization,
device=device,
dtype=dtype,
)
Expand Down
Loading

0 comments on commit 91726be

Please sign in to comment.