Merge pull request #93 from v0lta/fix/keep-ndims-Nd
Make preprocessing and postprocessing consistent across transforms
v0lta authored Jul 1, 2024
2 parents b87482f + 9981521 commit 85b898a
Showing 9 changed files with 626 additions and 685 deletions.
407 changes: 393 additions & 14 deletions src/ptwt/_util.py

Large diffs are not rendered by default.
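The bulk of the new shared code lives in src/ptwt/_util.py, whose diff is not rendered above. As a rough orientation, the sketch below reconstructs the fold/unfold idea from the call sites visible in the rendered hunks (`_fold_axes(data, 1)`, `_unfold_axes(fres, ds, 1)`); the *_sketch names are hypothetical stand-ins, not the library's implementation.

import torch

def fold_axes_sketch(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int]]:
    # Collapse every leading axis into one batch axis; keep the last
    # keep_no axes (1 for 1d transforms, 2 for 2d) for the transform.
    ds = list(data.shape)
    return data.reshape([-1] + ds[-keep_no:]), ds

def unfold_axes_sketch(data: torch.Tensor, ds: list[int], keep_no: int) -> torch.Tensor:
    # Restore the leading axes recorded by fold_axes_sketch.
    return data.reshape(ds[:-keep_no] + list(data.shape[-keep_no:]))

x = torch.randn(5, 3, 64)                      # e.g. [batch, channels, time]
folded, ds = fold_axes_sketch(x, keep_no=1)    # -> shape [15, 64]
assert unfold_axes_sketch(folded, ds, keep_no=1).shape == x.shape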

126 changes: 14 additions & 112 deletions src/ptwt/conv_transform.py
@@ -14,11 +14,13 @@
 from ._util import (
     Wavelet,
     _as_wavelet,
-    _fold_axes,
+    _check_same_device_dtype,
     _get_len,
-    _is_dtype_supported,
     _pad_symmetric,
-    _unfold_axes,
+    _postprocess_coeffs,
+    _postprocess_tensor,
+    _preprocess_coeffs,
+    _preprocess_tensor,
 )
 from .constants import BoundaryMode, WaveletCoeff2d

@@ -211,63 +213,6 @@ def _adjust_padding_at_reconstruction(
     return pad_end, pad_start


-def _preprocess_tensor_dec1d(
-    data: torch.Tensor,
-) -> tuple[torch.Tensor, list[int]]:
-    """Preprocess input tensor dimensions.
-
-    Args:
-        data (torch.Tensor): An input tensor of any shape.
-
-    Returns:
-        A tuple (data, ds) where data is a data tensor of shape
-        [new_batch, 1, to_process] and ds contains the original shape.
-    """
-    ds = list(data.shape)
-    if len(ds) == 1:
-        # assume time series
-        data = data.unsqueeze(0).unsqueeze(0)
-    elif len(ds) == 2:
-        # assume batched time series
-        data = data.unsqueeze(1)
-    else:
-        data, ds = _fold_axes(data, 1)
-        data = data.unsqueeze(1)
-    return data, ds
-
-
-def _postprocess_result_list_dec1d(
-    result_list: list[torch.Tensor], ds: list[int], axis: int
-) -> list[torch.Tensor]:
-    if len(ds) == 1:
-        result_list = [r_el.squeeze(0) for r_el in result_list]
-    elif len(ds) > 2:
-        # Unfold axes for the wavelets
-        result_list = [_unfold_axes(fres, ds, 1) for fres in result_list]
-    else:
-        result_list = result_list
-
-    if axis != -1:
-        result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]
-
-    return result_list
-
-
-def _preprocess_result_list_rec1d(
-    result_lst: Sequence[torch.Tensor],
-) -> tuple[Sequence[torch.Tensor], list[int]]:
-    # Fold axes for the wavelets
-    ds = list(result_lst[0].shape)
-    fold_coeffs: Sequence[torch.Tensor]
-    if len(ds) == 1:
-        fold_coeffs = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst]
-    elif len(ds) > 2:
-        fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst]
-    else:
-        fold_coeffs = result_lst
-    return fold_coeffs, ds
-
-
 def wavedec(
     data: torch.Tensor,
     wavelet: Union[Wavelet, str],
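The three deleted 1d helpers above encode a shape contract that the generic `_preprocess_tensor(..., ndim=1, axes=...)` replacement has to honor. A minimal sketch of that contract, assuming the same conventions as the deleted code (hypothetical function, not the library API):

import torch

def preprocess_1d_sketch(data: torch.Tensor) -> tuple[torch.Tensor, list[int]]:
    ds = list(data.shape)
    if data.dim() == 1:    # [time]        -> [1, 1, time]
        data = data.unsqueeze(0).unsqueeze(0)
    elif data.dim() == 2:  # [batch, time] -> [batch, 1, time]
        data = data.unsqueeze(1)
    else:                  # [..., time]   -> [prod(...), 1, time]
        data = data.reshape(-1, ds[-1]).unsqueeze(1)
    return data, ds

assert preprocess_1d_sketch(torch.randn(64))[0].shape == (1, 1, 64)
assert preprocess_1d_sketch(torch.randn(8, 64))[0].shape == (8, 1, 64)
assert preprocess_1d_sketch(torch.randn(2, 3, 64))[0].shape == (6, 1, 64)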
@@ -315,10 +260,6 @@ def wavedec(
         containing the wavelet coefficients. A denotes
         approximation and D detail coefficients.

-    Raises:
-        ValueError: If the dtype of the input data tensor is unsupported or
-            if more than one axis is provided.
-
     Example:
         >>> import torch
         >>> import ptwt, pywt
@@ -330,16 +271,7 @@
         >>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
         >>>              mode='zero', level=2)
     """
-    if axis != -1:
-        if isinstance(axis, int):
-            data = data.swapaxes(axis, -1)
-        else:
-            raise ValueError("wavedec transforms a single axis only.")
-
-    if not _is_dtype_supported(data.dtype):
-        raise ValueError(f"Input dtype {data.dtype} not supported")
-
-    data, ds = _preprocess_tensor_dec1d(data)
+    data, ds = _preprocess_tensor(data, ndim=1, axes=axis)

     dec_lo, dec_hi, _, _ = _get_filter_tensors(
         wavelet, flip=True, device=data.device, dtype=data.dtype
@@ -360,9 +292,7 @@
         result_list.append(res_lo.squeeze(1))
     result_list.reverse()

-    result_list = _postprocess_result_list_dec1d(result_list, ds, axis)
-
-    return result_list
+    return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=axis)


 def waverec(
@@ -381,11 +311,6 @@
     Returns:
         The reconstructed signal tensor.

-    Raises:
-        ValueError: If the dtype of the coeffs tensor is unsupported or if the
-            coefficients have incompatible shapes, dtypes or devices or if
-            more than one axis is provided.
-
     Example:
         >>> import torch
         >>> import ptwt, pywt
@@ -399,29 +324,11 @@
         >>>                 pywt.Wavelet('haar'))
     """
-    torch_device = coeffs[0].device
-    torch_dtype = coeffs[0].dtype
-    if not _is_dtype_supported(torch_dtype):
-        raise ValueError(f"Input dtype {torch_dtype} not supported")
-
-    for coeff in coeffs[1:]:
-        if torch_device != coeff.device:
-            raise ValueError("coefficients must be on the same device")
-        elif torch_dtype != coeff.dtype:
-            raise ValueError("coefficients must have the same dtype")
-
-    if axis != -1:
-        swap = []
-        if isinstance(axis, int):
-            for coeff in coeffs:
-                swap.append(coeff.swapaxes(axis, -1))
-            coeffs = swap
-        else:
-            raise ValueError("waverec transforms a single axis only.")
-
-    # fold channels, if necessary.
-    ds = list(coeffs[0].shape)
-    coeffs, ds = _preprocess_result_list_rec1d(coeffs)
+    # fold channels and swap axis, if necessary.
+    if not isinstance(coeffs, list):
+        coeffs = list(coeffs)
+    coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis)
+    torch_device, torch_dtype = _check_same_device_dtype(coeffs)

     _, _, rec_lo, rec_hi = _get_filter_tensors(
         wavelet, flip=False, device=torch_device, dtype=torch_dtype
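Judging by the per-coefficient checks removed in the hunk above, the new shared `_check_same_device_dtype` verifies that all coefficient tensors agree and hands back the common pair. A sketch under that assumption, reusing the error messages from the deleted code:

import torch

def check_same_device_dtype_sketch(coeffs: list[torch.Tensor]) -> tuple[torch.device, torch.dtype]:
    # Compare every coefficient against the first one.
    device, dtype = coeffs[0].device, coeffs[0].dtype
    for coeff in coeffs[1:]:
        if coeff.device != device:
            raise ValueError("coefficients must be on the same device")
        if coeff.dtype != dtype:
            raise ValueError("coefficients must have the same dtype")
    return device, dtype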
@@ -446,12 +353,7 @@
         if padr > 0:
             res_lo = res_lo[..., :-padr]

-    if len(ds) == 1:
-        res_lo = res_lo.squeeze(0)
-    elif len(ds) > 2:
-        res_lo = _unfold_axes(res_lo, ds, 1)
-
-    if axis != -1:
-        res_lo = res_lo.swapaxes(axis, -1)
+    # undo folding and swapping
+    res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=axis)

     return res_lo
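A quick way to exercise the rewritten 1d path end to end is a round trip over a batched, multi-channel signal, matching the consistency goal of this PR. This assumes ptwt and pywt are installed; the tolerance is illustrative.

import torch
import ptwt, pywt

data = torch.randn(2, 3, 64)   # leading axes get folded, last axis transformed
coeffs = ptwt.wavedec(data, pywt.Wavelet("haar"), mode="zero", level=2)
rec = ptwt.waverec(coeffs, pywt.Wavelet("haar"))
assert rec.shape == data.shape
assert torch.allclose(rec, data, atol=1e-6)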
103 changes: 12 additions & 91 deletions src/ptwt/conv_transform_2.py
@@ -6,26 +6,21 @@

 from __future__ import annotations

-from functools import partial
 from typing import Optional, Union

 import pywt
 import torch

 from ._util import (
     Wavelet,
     _as_wavelet,
-    _check_axes_argument,
-    _check_if_tensor,
-    _fold_axes,
+    _check_same_device_dtype,
     _get_len,
-    _is_dtype_supported,
-    _map_result,
     _outer,
     _pad_symmetric,
-    _swap_axes,
-    _undo_swap_axes,
-    _unfold_axes,
+    _postprocess_coeffs,
+    _postprocess_tensor,
+    _preprocess_coeffs,
+    _preprocess_tensor,
 )
 from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d
 from .conv_transform import (
@@ -107,32 +102,6 @@ def _fwt_pad2(
     return data_pad


-def _waverec2d_fold_channels_2d_list(
-    coeffs: WaveletCoeff2d,
-) -> tuple[WaveletCoeff2d, list[int]]:
-    # fold the input coefficients for processing conv2d_transpose.
-    ds = list(_check_if_tensor(coeffs[0]).shape)
-    return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds
-
-
-def _preprocess_tensor_dec2d(
-    data: torch.Tensor,
-) -> tuple[torch.Tensor, Union[list[int], None]]:
-    # Preprocess multidimensional input.
-    ds = None
-    if len(data.shape) == 2:
-        data = data.unsqueeze(0).unsqueeze(0)
-    elif len(data.shape) == 3:
-        # add a channel dimension for torch.
-        data = data.unsqueeze(1)
-    elif len(data.shape) >= 4:
-        data, ds = _fold_axes(data, 2)
-        data = data.unsqueeze(1)
-    elif len(data.shape) == 1:
-        raise ValueError("More than one input dimension required.")
-    return data, ds
-
-
 def wavedec2(
     data: torch.Tensor,
     wavelet: Union[Wavelet, str],
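Both deleted helpers above apply a per-tensor operation across the nested 2d coefficient structure (an approximation tensor followed by (lh, hl, hh) detail tuples), which the old code routed through `_map_result`. A stripped-down sketch of that pattern with hypothetical names; the library's `_map_result` handles the typed WaveletCoeff2d structure more carefully:

import torch

def map_coeffs_sketch(coeffs, fn):
    # Apply fn to the approximation tensor and to every detail tensor.
    return tuple(
        fn(c) if isinstance(c, torch.Tensor) else tuple(fn(t) for t in c)
        for c in coeffs
    )

approx = torch.randn(4, 2, 8, 8)
details = tuple(torch.randn(4, 2, 8, 8) for _ in range(3))
# fold the two leading axes of every coefficient, as the removed
# _waverec2d_fold_channels_2d_list did via _fold_axes(t, 2):
folded = map_coeffs_sketch((approx, details), lambda t: t.reshape(-1, *t.shape[-2:]))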
@@ -183,11 +152,6 @@ def wavedec2(
         A tuple containing the wavelet coefficients in pywt order,
         see :data:`ptwt.constants.WaveletCoeff2d`.

-    Raises:
-        ValueError: If the dimensionality or the dtype of the input data tensor
-            is unsupported or if the provided ``axes``
-            input has a length other than two.
-
     Example:
         >>> import torch
         >>> import ptwt, pywt
@@ -200,17 +164,7 @@
         >>>                 level=2, mode="zero")
     """
-    if not _is_dtype_supported(data.dtype):
-        raise ValueError(f"Input dtype {data.dtype} not supported")
-
-    if tuple(axes) != (-2, -1):
-        if len(axes) != 2:
-            raise ValueError("2D transforms work with two axes.")
-        else:
-            data = _swap_axes(data, list(axes))
-
-    wavelet = _as_wavelet(wavelet)
-    data, ds = _preprocess_tensor_dec2d(data)
+    data, ds = _preprocess_tensor(data, ndim=2, axes=axes)
     dec_lo, dec_hi, _, _ = _get_filter_tensors(
         wavelet, flip=True, device=data.device, dtype=data.dtype
     )
@@ -234,13 +188,7 @@
             res_ll = res_ll.squeeze(1)
     result: WaveletCoeff2d = res_ll, *result_lst

-    if ds:
-        _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2)
-        result = _map_result(result, _unfold_axes2)
-
-    if axes != (-2, -1):
-        undo_swap_fn = partial(_undo_swap_axes, axes=axes)
-        result = _map_result(result, undo_swap_fn)
+    result = _postprocess_coeffs(result, ndim=2, ds=ds, axes=axes)

     return result

@@ -286,35 +234,16 @@ def waverec2(
         >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
     """
-    if tuple(axes) != (-2, -1):
-        if len(axes) != 2:
-            raise ValueError("2D transforms work with two axes.")
-        else:
-            _check_axes_argument(list(axes))
-            swap_fn = partial(_swap_axes, axes=list(axes))
-            coeffs = _map_result(coeffs, swap_fn)
-
-    ds = None
     wavelet = _as_wavelet(wavelet)
-
-    res_ll = _check_if_tensor(coeffs[0])
-    torch_device = res_ll.device
-    torch_dtype = res_ll.dtype
-
-    if res_ll.dim() >= 4:
-        # avoid the channel sum, fold the channels into batches.
-        coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs)
-        res_ll = _check_if_tensor(coeffs[0])
-
-    if not _is_dtype_supported(torch_dtype):
-        raise ValueError(f"Input dtype {torch_dtype} not supported")
+    coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes)
+    torch_device, torch_dtype = _check_same_device_dtype(coeffs)

     _, _, rec_lo, rec_hi = _get_filter_tensors(
         wavelet, flip=False, device=torch_device, dtype=torch_dtype
     )
     filt_len = rec_lo.shape[-1]
     rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)

+    res_ll = coeffs[0]
     for c_pos, coeff_tuple in enumerate(coeffs[1:]):
         if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:
             raise ValueError(
@@ -325,11 +254,7 @@

         curr_shape = res_ll.shape
         for coeff in coeff_tuple:
-            if torch_device != coeff.device:
-                raise ValueError("coefficients must be on the same device")
-            elif torch_dtype != coeff.dtype:
-                raise ValueError("coefficients must have the same dtype")
-            elif coeff.shape != curr_shape:
+            if coeff.shape != curr_shape:
                 raise ValueError(
                     "All coefficients on each level must have the same shape"
                 )
@@ -362,10 +287,6 @@ def waverec2(
         if padr > 0:
             res_ll = res_ll[..., :-padr]

-    if ds:
-        res_ll = _unfold_axes(res_ll, list(ds), 2)
-
-    if axes != (-2, -1):
-        res_ll = _undo_swap_axes(res_ll, list(axes))
+    res_ll = _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes)

     return res_ll
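As with the 1d case, the refactored 2d path can be sanity-checked with a round trip over input that carries extra leading axes; ptwt and pywt are assumed installed, and the tolerance is illustrative.

import torch
import ptwt, pywt

data = torch.randn(5, 3, 32, 32)   # [batch, channels, height, width]
coeffs = ptwt.wavedec2(data, pywt.Wavelet("haar"), level=2, mode="zero")
rec = ptwt.waverec2(coeffs, pywt.Wavelet("haar"))
assert rec.shape == data.shape
assert torch.allclose(rec, data, atol=1e-6)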