Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specification of odd length boundary padding in MatrixWavedec #97

Merged
merged 18 commits into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from __future__ import annotations

import functools
import typing
from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload
import warnings
from collections.abc import Callable, Sequence
from typing import Any, NamedTuple, Optional, Protocol, Union, cast, overload

import numpy as np
import pywt
import torch
from typing_extensions import ParamSpec, TypeVar

from .constants import (
OrthogonalizeMethod,
Expand Down Expand Up @@ -253,3 +256,55 @@ def _map_result(
Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
)
return approx, *cast_result_lst


Param = ParamSpec("Param")
RetType = TypeVar("RetType")


def _deprecated_alias(
**aliases: str,
) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]:
"""Handle deprecated function and method arguments.

Use as follows::

@deprecated_alias(old_arg='new_arg')
def myfunc(new_arg):
...

Adapted from https://stackoverflow.com/a/49802489
"""

def rename_kwargs(
func_name: str,
kwargs: Param.kwargs,
aliases: dict[str, str],
) -> None:
"""Rename deprecated kwarg."""
for alias, new in aliases.items():
if alias in kwargs:
if new in kwargs:
raise TypeError(
f"{func_name} received both {alias} and {new} as arguments!"
f" {alias} is deprecated, use {new} instead."
)
warnings.warn(
message=(
f"`{alias}` is deprecated as an argument to `{func_name}`; use"
f" `{new}` instead."
),
category=DeprecationWarning,
stacklevel=3,
)
kwargs[new] = kwargs.pop(alias)

def deco(f: Callable[Param, RetType]) -> Callable[Param, RetType]:
@functools.wraps(f)
def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType:
rename_kwargs(f.__name__, kwargs, aliases)
return f(*args, **kwargs)

return wrapper

return deco
15 changes: 10 additions & 5 deletions src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def _fwt_pad(
wavelet: Union[Wavelet, str],
*,
mode: Optional[BoundaryMode] = None,
padding: Optional[tuple[int, int]] = None,
) -> torch.Tensor:
"""Pad the input signal to make the fwt matrix work.

Expand All @@ -144,21 +145,25 @@ def _fwt_pad(
data (torch.Tensor): Input data ``[batch_size, 1, time]``
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
mode :
The desired padding mode for extending the signal along the edges.
mode: The desired padding mode for extending the signal along the edges.
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
padding (tuple[int, int], optional): A tuple (padl, padr) with the
number of padded values on the left and right side of the last
axes of `data`. If None, the padding values are computed based
on the signal shape and the wavelet length. Defaults to None.

Returns:
A PyTorch tensor with the padded input data
"""
wavelet = _as_wavelet(wavelet)

# convert pywt to pytorch convention.
if mode is None:
mode = "reflect"
pytorch_mode = _translate_boundary_strings(mode)

padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
if padding is None:
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
else:
padl, padr = padding
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, [(padl, padr)])
else:
Expand Down
18 changes: 13 additions & 5 deletions src/ptwt/conv_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def _fwt_pad2(
wavelet: Union[Wavelet, str],
*,
mode: Optional[BoundaryMode] = None,
padding: Optional[tuple[int, int, int, int]] = None,
) -> torch.Tensor:
"""Pad data for the 2d FWT.

Expand All @@ -76,9 +77,13 @@ def _fwt_pad2(
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
mode :
The desired padding mode for extending the signal along the edges.
mode: The desired padding mode for extending the signal along the edges.
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
padding (tuple[int, int, int, int], optional): A tuple
(padl, padr, padt, padb) with the number of padded values
on the left, right, top and bottom side of the last two
axes of `data`. If None, the padding values are computed based
on the signal shape and the wavelet length. Defaults to None.

Returns:
The padded output tensor.
Expand All @@ -87,9 +92,12 @@ def _fwt_pad2(
if mode is None:
mode = "reflect"
pytorch_mode = _translate_boundary_strings(mode)
wavelet = _as_wavelet(wavelet)
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))

if padding is None:
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
else:
padl, padr, padt, padb = padding
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, [(padt, padb), (padl, padr)])
else:
Expand Down
25 changes: 18 additions & 7 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:


def _fwt_pad3(
data: torch.Tensor, wavelet: Union[Wavelet, str], *, mode: BoundaryMode
data: torch.Tensor,
wavelet: Union[Wavelet, str],
*,
mode: BoundaryMode,
padding: Optional[tuple[int, int, int, int, int, int]] = None,
) -> torch.Tensor:
"""Pad data for the 3d-FWT.

Expand All @@ -77,19 +81,26 @@ def _fwt_pad3(
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
mode :
The desired padding mode for extending the signal along the edges.
mode: The desired padding mode for extending the signal along the edges.
See :data:`ptwt.constants.BoundaryMode`.
padding (tuple[int, int, int, int, int, int], optional): A tuple
(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)
with the number of padded values on the respective side of the
last three axes of `data`.
If None, the padding values are computed based
on the signal shape and the wavelet length. Defaults to None.

Returns:
The padded output tensor.
"""
pytorch_mode = _translate_boundary_strings(mode)

wavelet = _as_wavelet(wavelet)
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
if padding is None:
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
else:
pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back = padding
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(
data, [(pad_front, pad_back), (pad_top, pad_bottom), (pad_left, pad_right)]
Expand Down
Loading
Loading