Skip to content

Commit

Permalink
Use typing.Literal for boundary mode, padding mode, and orthogonaliza…
Browse files Browse the repository at this point in the history
…tion mode (#77)
  • Loading branch information
cthoyt authored Jan 26, 2024
1 parent 6565e60 commit f390042
Show file tree
Hide file tree
Showing 18 changed files with 227 additions and 164 deletions.
7 changes: 6 additions & 1 deletion docs/ptwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,9 @@ ptwt.wavelets\_learnable module
:undoc-members:
:show-inheritance:


ptwt.constants
-------------------------------
.. automodule:: ptwt.constants
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion examples/speed_tests/timeitconv_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _to_jit_wavedec_2(data, wavelet):
means we have to stack the lists in the output.
"""
assert data.shape == (32, 1e3, 1e3), "Changing the chape requires re-tracing."
coeff = ptwt.wavedec2(data, wavelet, "periodic", level=5)
coeff = ptwt.wavedec2(data, wavelet, mode="periodic", level=5)
coeff2 = []
for c in coeff:
if isinstance(c, torch.Tensor):
Expand Down
7 changes: 5 additions & 2 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Utility methods to compute wavelet decompositions from a dataset."""

import typing
from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union

import numpy as np
import pywt
import torch

from ptwt.constants import OrthogonalizeMethod


class Wavelet(Protocol):
"""Wavelet object interface, based on the pywt wavelet object."""
Expand Down Expand Up @@ -43,8 +46,8 @@ def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
return wavelet


def _is_boundary_mode_supported(boundary_mode: Optional[str]) -> bool:
return boundary_mode in ["qr", "gramschmidt"]
def _is_boundary_mode_supported(boundary_mode: Optional[OrthogonalizeMethod]) -> bool:
return boundary_mode in typing.get_args(OrthogonalizeMethod)


def _is_dtype_supported(dtype: torch.dtype) -> bool:
Expand Down
38 changes: 38 additions & 0 deletions src/ptwt/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Constants and types used throughout the PyTorch Wavelet Toolbox."""

from typing import Literal, Union

__all__ = [
"BoundaryMode",
"ExtendedBoundaryMode",
"PaddingMode",
"OrthogonalizeMethod",
]

BoundaryMode = Literal["constant", "zero", "reflect", "periodic", "symmetric"]
"""
This is a type literal for the way of padding.
- Refection padding mirrors samples along the border.
- Zero padding pads zeros.
- Constant padding replicates border values.
- Periodic padding cyclically repeats samples.
- Symmetric padding mirrors samples along the border
"""

ExtendedBoundaryMode = Union[Literal["boundary"], BoundaryMode]

PaddingMode = Literal["full", "valid", "same", "sameshift"]
"""
The padding mode is used when construction convolution matrices.
"""

OrthogonalizeMethod = Literal["qr", "gramschmidt"]
"""
The method for orthogonalizing a matrix.
1. 'qr' relies on pytorch's dense qr implementation, it is fast but memory hungry.
2. 'gramschmidt' option is sparse, memory efficient, and slow.
Choose 'gramschmidt' if 'qr' runs out of memory.
"""
63 changes: 27 additions & 36 deletions src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

# Created by moritz wolter, 14.04.20
from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union, cast

import pywt
import torch
Expand All @@ -18,6 +18,7 @@
_pad_symmetric,
_unfold_axes,
)
from .constants import BoundaryMode


def _create_tensor(
Expand Down Expand Up @@ -106,7 +107,7 @@ def _get_pad(data_len: int, filt_len: int) -> Tuple[int, int]:
return padr, padl


def _translate_boundary_strings(pywt_mode: str) -> str:
def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
"""Translate pywt mode strings to PyTorch mode strings.
We support constant, zero, reflect, and periodic.
Expand All @@ -118,24 +119,25 @@ def _translate_boundary_strings(pywt_mode: str) -> str:
"""
if pywt_mode == "constant":
pt_mode = "replicate"
return "replicate"
elif pywt_mode == "zero":
pt_mode = "constant"
return "constant"
elif pywt_mode == "reflect":
pt_mode = pywt_mode
return pywt_mode
elif pywt_mode == "periodic":
pt_mode = "circular"
return "circular"
elif pywt_mode == "symmetric":
# pytorch does not support symmetric mode,
# we have our own implementation.
pt_mode = pywt_mode
else:
raise ValueError("Padding mode not supported.")
return pt_mode
return pywt_mode
raise ValueError(f"Padding mode not supported: {pywt_mode}")


def _fwt_pad(
data: torch.Tensor, wavelet: Union[Wavelet, str], mode: str = "reflect"
data: torch.Tensor,
wavelet: Union[Wavelet, str],
*,
mode: Optional[BoundaryMode] = None,
) -> torch.Tensor:
"""Pad the input signal to make the fwt matrix work.
Expand All @@ -145,29 +147,26 @@ def _fwt_pad(
data (torch.Tensor): Input data ``[batch_size, 1, time]``
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
mode (str): The desired way to pad. The following methods are supported::
"reflect", "zero", "constant", "periodic", "symmetric".
Refection padding mirrors samples along the border.
Zero padding pads zeros.
Constant padding replicates border values.
Periodic padding cyclically repeats samples.
This function defaults to reflect.
mode :
The desired padding mode for extending the signal along the edges.
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
Returns:
torch.Tensor: A PyTorch tensor with the padded input data
"""
wavelet = _as_wavelet(wavelet)

# convert pywt to pytorch convention.
mode = _translate_boundary_strings(mode)
if mode is None:
mode = cast(BoundaryMode, "reflect")
pytorch_mode = _translate_boundary_strings(mode)

padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
if mode == "symmetric":
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, [(padl, padr)])
else:
data_pad = torch.nn.functional.pad(data, [padl, padr], mode=mode)
data_pad = torch.nn.functional.pad(data, [padl, padr], mode=pytorch_mode)
return data_pad


Expand Down Expand Up @@ -263,7 +262,8 @@ def _preprocess_result_list_rec1d(
def wavedec(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
mode: str = "reflect",
*,
mode: BoundaryMode = "reflect",
level: Optional[int] = None,
axis: int = -1,
) -> List[torch.Tensor]:
Expand All @@ -276,18 +276,9 @@ def wavedec(
the name of a pywt wavelet.
Please consider the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
mode (str): The desired padding mode. Padding extends the signal along
the edges. Supported methods are::
"reflect", "zero", "constant", "periodic", "symmetric".
Defaults to "reflect".
Symmetric padding mirrors samples along the border.
Refection padding reflects samples along the border.
Zero padding pads zeros.
Constant padding replicates border values.
Periodic padding cyclically repeats samples.
mode :
The desired padding mode for extending the signal along the edges.
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
level (int): The scale level to be computed.
Defaults to None.
axis (int): Compute the transform over this axis instead of the
Expand Down
38 changes: 21 additions & 17 deletions src/ptwt/conv_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from functools import partial
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, cast

import pywt
import torch
Expand All @@ -25,6 +25,7 @@
_undo_swap_axes,
_unfold_axes,
)
from .constants import BoundaryMode
from .conv_transform import (
_adjust_padding_at_reconstruction,
_get_filter_tensors,
Expand Down Expand Up @@ -58,7 +59,10 @@ def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:


def _fwt_pad2(
data: torch.Tensor, wavelet: Union[Wavelet, str], mode: str = "reflect"
data: torch.Tensor,
wavelet: Union[Wavelet, str],
*,
mode: Optional[BoundaryMode] = None,
) -> torch.Tensor:
"""Pad data for the 2d FWT.
Expand All @@ -68,25 +72,26 @@ def _fwt_pad2(
data (torch.Tensor): Input data with 4 dimensions.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
mode (str): The padding mode.
Supported modes are::
"reflect", "zero", "constant", "periodic", "symmetric".
"reflect" is the default mode.
mode :
The desired padding mode for extending the signal along the edges.
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
Returns:
The padded output tensor.
"""
mode = _translate_boundary_strings(mode)
if mode is None:
mode = cast(BoundaryMode, "reflect")
pytorch_mode = _translate_boundary_strings(mode)
wavelet = _as_wavelet(wavelet)
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
if mode == "symmetric":
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, [(padt, padb), (padl, padr)])
else:
data_pad = torch.nn.functional.pad(data, [padl, padr, padt, padb], mode=mode)
data_pad = torch.nn.functional.pad(
data, [padl, padr, padt, padb], mode=pytorch_mode
)
return data_pad


Expand Down Expand Up @@ -122,7 +127,8 @@ def _preprocess_tensor_dec2d(
def wavedec2(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
mode: str = "reflect",
*,
mode: BoundaryMode = "reflect",
level: Optional[int] = None,
axes: Tuple[int, int] = (-2, -1),
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
Expand All @@ -140,11 +146,9 @@ def wavedec2(
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet. Refer to the output of
``pywt.wavelist(kind="discrete")`` for a list of possible choices.
mode (str): The padding mode. Options are::
"reflect", "zero", "constant", "periodic", "symmetric".
This function defaults to "reflect".
mode :
The desired padding mode for extending the signal along the edges.
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
level (int): The number of desired scales.
Defaults to None.
axes (Tuple[int, int]): Compute the transform over these axes instead of the
Expand Down
26 changes: 13 additions & 13 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_undo_swap_axes,
_unfold_axes,
)
from .constants import BoundaryMode
from .conv_transform import (
_adjust_padding_at_reconstruction,
_get_filter_tensors,
Expand Down Expand Up @@ -63,7 +64,7 @@ def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:


def _fwt_pad3(
data: torch.Tensor, wavelet: Union[Wavelet, str], mode: str
data: torch.Tensor, wavelet: Union[Wavelet, str], *, mode: BoundaryMode
) -> torch.Tensor:
"""Pad data for the 3d-FWT.
Expand All @@ -73,37 +74,38 @@ def _fwt_pad3(
data (torch.Tensor): Input data with 4 dimensions.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
mode (str): The padding mode. Supported modes are::
"reflect", "zero", "constant", "periodic", "symmetric".
mode :
The desired padding mode for extending the signal along the edges.
See :data:`ptwt.constants.BoundaryMode`.
Returns:
The padded output tensor.
"""
mode = _translate_boundary_strings(mode)
pytorch_mode = _translate_boundary_strings(mode)

wavelet = _as_wavelet(wavelet)
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
if mode == "symmetric":
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(
data, [(pad_front, pad_back), (pad_top, pad_bottom), (pad_left, pad_right)]
)
else:
data_pad = torch.nn.functional.pad(
data,
[pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back],
mode=mode,
mode=pytorch_mode,
)
return data_pad


def wavedec3(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
mode: str = "zero",
*,
mode: BoundaryMode = "zero",
level: Optional[int] = None,
axes: Tuple[int, int, int] = (-3, -2, -1),
) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
Expand All @@ -114,11 +116,9 @@ def wavedec3(
[batch_size, length, height, width]
wavelet (Union[Wavelet, str]): The wavelet to transform with.
``pywt.wavelist(kind='discrete')`` lists possible choices.
mode (str): The padding mode. Possible options are::
"reflect", "zero", "constant", "periodic", "symmetric".
Defaults to "zero".
mode :
The desired padding mode for extending the signal along the edges.
Defaults to "zero". See :data:`ptwt.constants.BoundaryMode`.
level (Optional[int]): The maximum decomposition level.
This argument defaults to None.
axes (Tuple[int, int, int]): Compute the transform over these axes
Expand Down
Loading

0 comments on commit f390042

Please sign in to comment.