Skip to content

Commit

Permalink
Go over docs
Browse files Browse the repository at this point in the history
  • Loading branch information
felixblanke committed Jul 2, 2024
1 parent 9d5a779 commit 3fc692d
Show file tree
Hide file tree
Showing 15 changed files with 214 additions and 149 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@
"BaseMatrixWaveDec": "ptwt.matmul_transform.BaseMatrixWaveDec",
"BoundaryMode": "ptwt.constants.BoundaryMode",
"ExtendedBoundaryMode": "ptwt.constants.ExtendedBoundaryMode",
"OrthogonalizeMethod": "ptwt.constants.OrthogonalizeMethod"
}
12 changes: 10 additions & 2 deletions docs/ref/boundary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,24 @@
Boundary handling modes
=======================

As is typical the algorithms in this toolbox are designed to be applied to signals of finite size.
As is typical the algorithms in this toolbox are designed to be applied
to signal tensors of finite size.
This requires some handling of the signal boundaries to apply the
wavelet transform convolutions.

.. TODO: Add explanation page on signal extension / boundary wavelets
This toolbox implements two different approaches to boundary handling:

* signal extension via padding
* using boundary filters for coeffients on the signal boundary

Signal extension via padding
----------------------------

.. _`modes.padding`:

Signal extensions by padding are applied using :func:`torch.nn.functional.pad`.
The following modes of padding are supported:

.. autoclass:: BoundaryMode


Expand Down
2 changes: 1 addition & 1 deletion docs/ref/return-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Transforms in two dimensions

.. autoclass:: WaveletDetailTuple2d
:members:
:undoc-members:
:class-doc-from: class
:show-inheritance:
:member-order: bysource

Expand Down
5 changes: 3 additions & 2 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
"""Translate pywt mode strings to PyTorch mode strings.
We support constant, zero, reflect, periodic and symmetric.
Unfortunately, "constant" has different meanings in the
We support ``constant``, ``zero``, ``reflect``,
``periodic`` and ``symmetric``.
Unfortunately, ``constant`` has different meanings in the
Pytorch and PyWavelet communities.
Raises:
Expand Down
31 changes: 25 additions & 6 deletions src/ptwt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,25 @@ def from_wavelet(cls, wavelet: Wavelet, dtype: torch.dtype) -> WaveletTensorTupl
"""
This is a type literal for the way of padding used at boundaries.
- Refection padding mirrors samples along the border (``reflect``)
- Zero padding pads zeros (``zero``)
- Constant padding replicates border values (``constant``)
- Periodic padding cyclically repeats samples (``periodic``)
- Symmetric padding mirrors samples along the border (``symmetric``)
- ``reflect``: Refection padding reflects samples at the border::
... x3 x2 | x1 x2 ... xn | xn-1 xn-2 ...
- ``zero``: Zero padding extends the signal with zeros::
... 0 0 | x1 x2 ... xn | 0 0 ...
- ``constant``: Constant padding replicates border values::
... x1 x1 | x1 x2 ... xn | xn xn ...
- ``periodic``: Periodic padding cyclically repeats samples::
... xn-1 xn | x1 x2 ... xn | x1 x2 ...
- ``symmetric``: Symmetric padding mirrors samples along the border::
... x2 x1 | x1 x2 ... xn | xn xn-1 ...
"""

ExtendedBoundaryMode = Union[Literal["boundary"], BoundaryMode]
Expand Down Expand Up @@ -153,6 +167,11 @@ class WaveletDetailTuple2d(NamedTuple):
This is a type alias for a named tuple ``(H, V, D)`` of detail coefficient tensors
where ``H`` denotes horizontal, ``V`` vertical and ``D`` diagonal coefficients.
We follow the pywt convention for the orientation of axes , i.e.
axis 0 is horizontal and axis 1 vertical.
For more information, see the
`pywt docs <https://pywavelets.readthedocs.io/en/latest/ref/2d-dwt-and-idwt.html#d-coordinate-conventions>`_.
"""

horizontal: torch.Tensor
Expand Down Expand Up @@ -191,7 +210,7 @@ class WaveletDetailTuple2d(NamedTuple):
of length :math:`n + 1`.
``cAn`` denotes a tensor of approximation coefficients for the `n`-th level
of decomposition. ``Tl`` is a tuple of detail coefficients for level ``l``,
see :data:`ptwt.constants.WaveletDetailTuple2d`.
see :class:`ptwt.constants.WaveletDetailTuple2d`.
Note that this type always contains an approximation coefficient tensor but does not
necesseraily contain any detail coefficients.
Expand Down
4 changes: 2 additions & 2 deletions src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,13 @@ def wavedec(
Please consider the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
mode: The desired padding mode for extending the signal along the edges.
See :data:`ptwt.constants.BoundaryMode`. Defaults to "reflect".
See :data:`ptwt.constants.BoundaryMode`. Defaults to ``reflect``.
level (int, optional): The maximum decomposition level.
If None, the level is computed based on the signal shape.
Defaults to None.
axis (int): Compute the transform over this axis of the `data` tensor.
Defaults to -1.
Returns:
A list::
Expand Down Expand Up @@ -166,6 +165,7 @@ def waverec(
Returns:
The reconstructed signal tensor.
Its shape depends on the shape of the input to :func:`ptwt.wavedec`.
Example:
>>> import ptwt, torch
Expand Down
23 changes: 9 additions & 14 deletions src/ptwt/conv_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def wavedec2(
level: Optional[int] = None,
axes: tuple[int, int] = (-2, -1),
) -> WaveletCoeff2d:
r"""Run a two-dimensional fast wavelet transformation.
r"""Compute the two-dimensional fast wavelet transformation.
This function relies on two-dimensional convolutions.
Outer products allow the construction of 2d filters
Expand All @@ -131,17 +131,13 @@ def wavedec2(
Args:
data (torch.Tensor): The input data tensor with at least two dimensions.
By default 2d inputs are interpreted as ``[height, width]``,
3d inputs are interpreted as ``[batch_size, height, width]``.
4d inputs are interpreted as ``[batch_size, channels, height, width]``.
The ``axes`` argument allows other interpretations.
By default, the last two axes are transformed.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
mode :
The desired padding mode for extending the signal along the edges.
See :data:`ptwt.constants.BoundaryMode`. Defaults to "reflect".
mode: The desired padding mode for extending the signal along the edges.
See :data:`ptwt.constants.BoundaryMode`. Defaults to ``reflect``.
level (int, optional): The maximum decomposition level.
If None, the level is computed based on the signal shape.
Defaults to None.
Expand Down Expand Up @@ -199,7 +195,7 @@ def waverec2(
"""Reconstruct a 2d signal from wavelet coefficients.
Args:
coeffs: The wavelet coefficient tuple produced by :data:`ptwt.wavedec2`.
coeffs: The wavelet coefficient tuple produced by :func:`ptwt.wavedec2`.
See :data:`ptwt.constants.WaveletCoeff2d`
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
Expand All @@ -209,13 +205,12 @@ def waverec2(
tensor. Defaults to (-2, -1).
Returns:
The reconstructed signal tensor of shape ``[batch, height, width]`` or
``[batch, channel, height, width]`` depending on the input
to :data:`ptwt.wavedec2`.
The reconstructed signal tensor.
Its shape depends on the shape of the input to :func:`ptwt.wavedec2`.
Raises:
ValueError: If coeffs is not in a shape as returned from
:data:`ptwt.wavedec2` or if the dtype is not supported or
ValueError: If `coeffs` is not in a shape as returned from
:func:`ptwt.wavedec2` or if the dtype is not supported or
if the provided axes input has length other
than two or if the same axes it repeated twice.
Expand Down
18 changes: 9 additions & 9 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,18 @@ def wavedec3(
level: Optional[int] = None,
axes: tuple[int, int, int] = (-3, -2, -1),
) -> WaveletCoeffNd:
"""Compute a three-dimensional wavelet transform.
"""Compute the three-dimensional fast wavelet transformation.
Args:
data (torch.Tensor): The input data. For example of shape
``[batch_size, length, height, width]``
data (torch.Tensor): The input data tensor with at least three dimensions.
By default, the last three axes are transformed.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
mode: The desired padding mode for extending the signal
along the edges. See :data:`ptwt.constants.BoundaryMode`.
Defaults to "zero".
Defaults to ``zero``.
level (int, optional): The maximum decomposition level.
If None, the level is computed based on the signal shape.
Defaults to None.
Expand Down Expand Up @@ -190,7 +190,7 @@ def waverec3(
"""Reconstruct a 3d signal from wavelet coefficients.
Args:
coeffs: The wavelet coefficient tuple produced by :data:`ptwt.wavedec3`,
coeffs: The wavelet coefficient tuple produced by :func:`ptwt.wavedec3`,
see :data:`ptwt.constants.WaveletCoeffNd`.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
Expand All @@ -200,12 +200,12 @@ def waverec3(
tensor. Defaults to (-3, -2, -1).
Returns:
The reconstructed four-dimensional signal tensor of shape
``[batch, depth, height, width]``.
The reconstructed signal tensor.
Its shape depends on the shape of the input to :func:`ptwt.wavedec3`.
Raises:
ValueError: If coeffs is not in a shape as returned
from :data:`ptwt.wavedec3` or if the dtype is not supported or
ValueError: If `coeffs` is not in a shape as returned
from :func:`ptwt.wavedec3` or if the dtype is not supported or
if the provided axes input has length other than three or
if the same axes it repeated three.
Expand Down
46 changes: 27 additions & 19 deletions src/ptwt/matmul_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Implement matrix-based fwt and ifwt.
"""Implement matrix-based FWT and iFWT.
This module uses boundary filters instead of padding.
Expand Down Expand Up @@ -169,7 +169,7 @@ class MatrixWavedec(BaseMatrixWaveDec):
"""Compute the 1d fast wavelet transform using sparse matrices.
This transform is the sparse matrix correspondant to
:data:`ptwt.wavedec`. The convolution operations are
:func:`ptwt.wavedec`. The convolution operations are
implemented as a matrix-vector product between a
sparse transformation matrix and the input signal.
This transform uses boundary wavelets instead of padding to
Expand All @@ -183,12 +183,13 @@ class MatrixWavedec(BaseMatrixWaveDec):
The matrix is therefore constructed only once and reused
in further calls.
The sparse transformation matrix can be accessed
via the :data:`sparse_fwt_operator` property.
via the :attr:`sparse_fwt_operator` property.
Note:
On each level of the transform the convolved signal
is required to be of even length. This transform uses
zero padding to transform coefficients with an odd length.
padding to transform coefficients with an odd length,
with the padding mode specified by `odd_coeff_padding_mode`.
To avoid padding consider transforming signals
with a length divisable by :math:`2^L`
for a :math:`L`-level transform.
Expand Down Expand Up @@ -220,18 +221,18 @@ def __init__(
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
level (int, optional): The level up to which to compute the fwt. If None,
level (int, optional): The level up to which to compute the FWT. If None,
the maximum level based on the signal length is chosen. Defaults to
None.
axis (int): The axis we would like to transform. Defaults to -1.
orthogonalization: The method used to orthogonalize
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
Defaults to 'qr'.
Defaults to ``qr``.
odd_coeff_padding_mode: The constructed FWT matrices require inputs
with even lengths. Thus, any odd-length approximation coefficients
are padded to an even length using this mode,
see :data:`ptwt.constants.BoundaryMode`.
Defaults to 'zero'.
Defaults to ``zero``.
.. versionchanged:: 1.10
The argument `boundary` has been renamed to `orthogonalization`.
Expand Down Expand Up @@ -272,8 +273,13 @@ def sparse_fwt_operator(self) -> torch.Tensor:
the whole operation is padding-free and can be expressed
as a single matrix multiply.
The operation ``torch.sparse.mm(sparse_fwt_operator, data.T)``
computes a batched fwt.
The operation
.. code-block:: python
torch.sparse.mm(sparse_fwt_operator, data.T)
computes a batched FWT.
This property exists to make the operator matrix transparent.
Calling the object will handle odd-length inputs properly.
Expand Down Expand Up @@ -348,16 +354,14 @@ def _construct_analysis_matrices(
self.size_list.append(curr_length)

def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]:
"""Compute the matrix fwt for the given input signal.
"""Compute the matrix FWT for the given input signal.
Matrix FWTs are used to avoid padding.
Args:
input_signal (torch.Tensor): Batched input data.
An example shape could be ``[batch_size, time]``.
Inputs can have any dimension.
input_signal (torch.Tensor): Input data to transform.
This transform affects the last axis by default.
Use the axis argument in the constructor to choose
Use the `axis` argument in the constructor to choose
another axis.
Returns:
Expand Down Expand Up @@ -527,7 +531,7 @@ def __init__(
defaults to -1 or the last axis.
orthogonalization: The method used to orthogonalize
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
Defaults to 'qr'.
Defaults to ``qr``.
.. versionchanged:: 1.10
The argument `boundary` has been renamed to `orthogonalization`.
Expand Down Expand Up @@ -565,7 +569,11 @@ def sparse_ifwt_operator(self) -> torch.Tensor:
as a single matrix multiply.
Having concatenated the analysis coefficients,
torch.sparse.mm(sparse_ifwt_operator, coefficients.T)
.. code-block:: python
torch.sparse.mm(sparse_ifwt_operator, coefficients.T)
to computes a batched iFWT.
This functionality is mainly here to make the operator-matrix
Expand Down Expand Up @@ -638,19 +646,19 @@ def _construct_synthesis_matrices(
curr_length = curr_length // 2

def __call__(self, coefficients: WaveletCoeff1d) -> torch.Tensor:
"""Run the synthesis or inverse matrix fwt.
"""Run the synthesis or inverse matrix FWT.
Args:
coefficients: The coefficients produced by the forward transform
:data:`MatrixWavedec`. See :data:`ptwt.constants.WaveletCoeff1d`.
:class:`MatrixWavedec`. See :data:`ptwt.constants.WaveletCoeff1d`.
Returns:
The input signal reconstruction.
Raises:
ValueError: If the decomposition level is not a positive integer or if the
coefficients are not in the shape as it is returned from a
`MatrixWavedec` object.
:class:`MatrixWavedec` object.
"""
if not isinstance(coefficients, list):
coefficients = list(coefficients)
Expand Down
Loading

0 comments on commit 3fc692d

Please sign in to comment.