Skip to content

Commit

Permalink
Merge pull request #94 from v0lta/fix/packet-axis
Browse files Browse the repository at this point in the history
Fix non-default axis in packets
  • Loading branch information
v0lta authored Jun 26, 2024
2 parents ef4f80a + 48ad03b commit 848d0f7
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 108 deletions.
52 changes: 26 additions & 26 deletions src/ptwt/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pywt
import torch

from ._util import Wavelet, _as_wavelet
from ._util import Wavelet, _as_wavelet, _swap_axes, _undo_swap_axes
from .constants import (
ExtendedBoundaryMode,
OrthogonalizeMethod,
Expand Down Expand Up @@ -113,9 +113,6 @@ def __init__(
self.maxlevel: Optional[int] = None
self.axis = axis
if data is not None:
if len(data.shape) == 1:
# add a batch dimension.
data = data.unsqueeze(0)
self.transform(data, maxlevel)
else:
self.data = {}
Expand All @@ -134,7 +131,7 @@ def transform(
"""
self.data = {}
if maxlevel is None:
maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len)
maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len)
self.maxlevel = maxlevel
self._recursive_dwt(data, level=0, path="")
return self
Expand Down Expand Up @@ -166,13 +163,15 @@ def reconstruct(self) -> WaveletPacket:
for node in self.get_level(level):
data_a = self[node + "a"]
data_b = self[node + "d"]
rec = self._get_waverec(data_a.shape[-1])([data_a, data_b])
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_b])
if level > 0:
if rec.shape[-1] != self[node].shape[-1]:
if rec.shape[self.axis] != self[node].shape[self.axis]:
assert (
rec.shape[-1] == self[node].shape[-1] + 1
rec.shape[self.axis] == self[node].shape[self.axis] + 1
), "padding error, please open an issue on github"
rec = rec[..., :-1]
rec = rec.swapaxes(self.axis, -1)[..., :-1].swapaxes(
-1, self.axis
)
self[node] = rec
return self

Expand Down Expand Up @@ -227,12 +226,12 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st
return graycode_order

def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
if not self.maxlevel:
if self.maxlevel is None:
raise AssertionError

self.data[path] = data
if level < self.maxlevel:
res_lo, res_hi = self._get_wavedec(data.shape[-1])(data)
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
self._recursive_dwt(res_lo, level + 1, path + "a")
self._recursive_dwt(res_hi, level + 1, path + "d")

Expand Down Expand Up @@ -340,13 +339,10 @@ def transform(
"""
self.data = {}
if maxlevel is None:
maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len)
min_transform_size = min(_swap_axes(data, self.axes).shape[-2:])
maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
self.maxlevel = maxlevel

if data.dim() == 2:
# add batch dim to unbatched input
data = data.unsqueeze(0)

self._recursive_dwt2d(data, level=0, path="")
return self

Expand All @@ -359,30 +355,33 @@ def reconstruct(self) -> WaveletPacket2D:
a reconstruction from the leaves.
"""
if self.maxlevel is None:
self.maxlevel = pywt.dwt_max_level(
min(self[""].shape[-2:]), self.wavelet.dec_len
)
min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:])
self.maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)

for level in reversed(range(self.maxlevel)):
for node in WaveletPacket2D.get_natural_order(level):
data_a = self[node + "a"]
data_h = self[node + "h"]
data_v = self[node + "v"]
data_d = self[node + "d"]
rec = self._get_waverec(data_a.shape[-2:])(
transform_size = _swap_axes(data_a, self.axes).shape[-2:]
rec = self._get_waverec(transform_size)(
(data_a, WaveletDetailTuple2d(data_h, data_v, data_d))
)
if level > 0:
if rec.shape[-1] != self[node].shape[-1]:
rec = _swap_axes(rec, self.axes)
swapped_node = _swap_axes(self[node], self.axes)
if rec.shape[-1] != swapped_node.shape[-1]:
assert (
rec.shape[-1] == self[node].shape[-1] + 1
rec.shape[-1] == swapped_node.shape[-1] + 1
), "padding error, please open an issue on GitHub"
rec = rec[..., :-1]
if rec.shape[-2] != self[node].shape[-2]:
if rec.shape[-2] != swapped_node.shape[-2]:
assert (
rec.shape[-2] == self[node].shape[-2] + 1
rec.shape[-2] == swapped_node.shape[-2] + 1
), "padding error, please open an issue on GitHub"
rec = rec[..., :-1, :]
rec = _undo_swap_axes(rec, self.axes)
self[node] = rec
return self

Expand Down Expand Up @@ -468,12 +467,13 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
return _fsdict_func

def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
if not self.maxlevel:
if self.maxlevel is None:
raise AssertionError

self.data[path] = data
if level < self.maxlevel:
result = self._get_wavedec(data.shape[-2:])(data)
transform_size = _swap_axes(data, self.axes).shape[-2:]
result = self._get_wavedec(transform_size)(data)

# assert for type checking
assert len(result) == 2
Expand Down
Loading

0 comments on commit 848d0f7

Please sign in to comment.