Skip to content

Commit

Permalink
merge.
Browse files Browse the repository at this point in the history
  • Loading branch information
v0lta committed Jul 1, 2024
2 parents 6254f13 + b87482f commit e281c85
Show file tree
Hide file tree
Showing 12 changed files with 414 additions and 144 deletions.
64 changes: 61 additions & 3 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

from __future__ import annotations

import functools
import warnings
import typing
from collections.abc import Callable, Sequence
from functools import partial
from collections.abc import Callable, Sequence
from typing import Any, Literal, NamedTuple, Optional, Protocol, Union, cast, overload

import numpy as np
import pywt
import torch
from typing_extensions import ParamSpec, TypeVar

from .constants import (
OrthogonalizeMethod,
Expand Down Expand Up @@ -91,8 +94,10 @@ def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
return wavelet


def _is_boundary_mode_supported(boundary_mode: Optional[OrthogonalizeMethod]) -> bool:
return boundary_mode in typing.get_args(OrthogonalizeMethod)
def _is_orthogonalize_method_supported(
orthogonalization: Optional[OrthogonalizeMethod],
) -> bool:
return orthogonalization in typing.get_args(OrthogonalizeMethod)


def _is_dtype_supported(dtype: torch.dtype) -> bool:
Expand Down Expand Up @@ -630,4 +635,57 @@ def _postprocess_tensor(
"""
# interpreting data as the approximation coeffs of a 0-level FWT
# allows us to reuse the `_postprocess_coeffs` code
# return approx, *cast_result_lst
return _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)[0]


Param = ParamSpec("Param")
RetType = TypeVar("RetType")


def _deprecated_alias(
**aliases: str,
) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]:
"""Handle deprecated function and method arguments.
Use as follows::
@_deprecated_alias(old_arg='new_arg')
def myfunc(new_arg):
...
Adapted from https://stackoverflow.com/a/49802489
"""

def rename_kwargs(
func_name: str,
kwargs: Param.kwargs,
aliases: dict[str, str],
) -> None:
"""Rename deprecated kwarg."""
for alias, new in aliases.items():
if alias in kwargs:
if new in kwargs:
raise TypeError(
f"{func_name} received both {alias} and {new} as arguments!"
f" {alias} is deprecated, use {new} instead."
)
warnings.warn(
message=(
f"`{alias}` is deprecated as an argument to `{func_name}`; use"
f" `{new}` instead."
),
category=DeprecationWarning,
stacklevel=3,
)
kwargs[new] = kwargs.pop(alias)

def deco(f: Callable[Param, RetType]) -> Callable[Param, RetType]:
@functools.wraps(f)
def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType:
rename_kwargs(f.__name__, kwargs, aliases)
return f(*args, **kwargs)

return wrapper

return deco
15 changes: 10 additions & 5 deletions src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def _fwt_pad(
wavelet: Union[Wavelet, str],
*,
mode: Optional[BoundaryMode] = None,
padding: Optional[tuple[int, int]] = None,
) -> torch.Tensor:
"""Pad the input signal to make the fwt matrix work.
Expand All @@ -146,21 +147,25 @@ def _fwt_pad(
data (torch.Tensor): Input data ``[batch_size, 1, time]``
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
mode :
The desired padding mode for extending the signal along the edges.
mode: The desired padding mode for extending the signal along the edges.
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
padding (tuple[int, int], optional): A tuple (padl, padr) with the
number of padded values on the left and right side of the last
axes of `data`. If None, the padding values are computed based
on the signal shape and the wavelet length. Defaults to None.
Returns:
A PyTorch tensor with the padded input data
"""
wavelet = _as_wavelet(wavelet)

# convert pywt to pytorch convention.
if mode is None:
mode = "reflect"
pytorch_mode = _translate_boundary_strings(mode)

padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
if padding is None:
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
else:
padl, padr = padding
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, [(padl, padr)])
else:
Expand Down
18 changes: 13 additions & 5 deletions src/ptwt/conv_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _fwt_pad2(
wavelet: Union[Wavelet, str],
*,
mode: Optional[BoundaryMode] = None,
padding: Optional[tuple[int, int, int, int]] = None,
) -> torch.Tensor:
"""Pad data for the 2d FWT.
Expand All @@ -72,9 +73,13 @@ def _fwt_pad2(
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
mode :
The desired padding mode for extending the signal along the edges.
mode: The desired padding mode for extending the signal along the edges.
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
padding (tuple[int, int, int, int], optional): A tuple
(padl, padr, padt, padb) with the number of padded values
on the left, right, top and bottom side of the last two
axes of `data`. If None, the padding values are computed based
on the signal shape and the wavelet length. Defaults to None.
Returns:
The padded output tensor.
Expand All @@ -83,9 +88,12 @@ def _fwt_pad2(
if mode is None:
mode = "reflect"
pytorch_mode = _translate_boundary_strings(mode)
wavelet = _as_wavelet(wavelet)
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))

if padding is None:
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
else:
padl, padr, padt, padb = padding
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, [(padt, padb), (padl, padr)])
else:
Expand Down
25 changes: 18 additions & 7 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:


def _fwt_pad3(
data: torch.Tensor, wavelet: Union[Wavelet, str], *, mode: BoundaryMode
data: torch.Tensor,
wavelet: Union[Wavelet, str],
*,
mode: BoundaryMode,
padding: Optional[tuple[int, int, int, int, int, int]] = None,
) -> torch.Tensor:
"""Pad data for the 3d-FWT.
Expand All @@ -73,19 +77,26 @@ def _fwt_pad3(
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
mode :
The desired padding mode for extending the signal along the edges.
mode: The desired padding mode for extending the signal along the edges.
See :data:`ptwt.constants.BoundaryMode`.
padding (tuple[int, int, int, int, int, int], optional): A tuple
(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)
with the number of padded values on the respective side of the
last three axes of `data`.
If None, the padding values are computed based
on the signal shape and the wavelet length. Defaults to None.
Returns:
The padded output tensor.
"""
pytorch_mode = _translate_boundary_strings(mode)

wavelet = _as_wavelet(wavelet)
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
if padding is None:
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
else:
pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back = padding
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(
data, [(pad_front, pad_back), (pad_top, pad_bottom), (pad_left, pad_right)]
Expand Down
Loading

0 comments on commit e281c85

Please sign in to comment.