Skip to content

Commit

Permalink
Merge pull request #96 from v0lta/feature/packets-partial-refinement
Browse files Browse the repository at this point in the history
Add lazy init to packets for partial tree expansion
  • Loading branch information
v0lta authored Jun 27, 2024
2 parents 4e271a8 + bb79f6e commit fa7af3d
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 70 deletions.
2 changes: 1 addition & 1 deletion examples/deepfake_analysis/packet_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def load_images(path: str) -> list:

if __name__ == "__main__":
freq_path = ptwt.WaveletPacket2D.get_freq_order(level=3)
frequency_path = ptwt.WaveletPacket2D.get_natural_order(level=3)
natural_path = ptwt.WaveletPacket2D.get_natural_order(level=3)
print("Loading ffhq images:")
ffhq_images = load_images("./ffhq_style_gan/source_data/A_ffhq")
print("processing ffhq")
Expand Down
185 changes: 137 additions & 48 deletions src/ptwt/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from __future__ import annotations

import collections
from collections.abc import Sequence
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from itertools import product
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload
from typing import TYPE_CHECKING, Literal, Optional, Union, overload

import numpy as np
import pywt
Expand Down Expand Up @@ -65,7 +65,9 @@ def __init__(
) -> None:
"""Create a wavelet packet decomposition object.
The decompositions will rely on padded fast wavelet transforms.
The packet tree is initialized lazily, i.e. a coefficient is only
calculated as it is retrieved. This allows for partial expansion
of the wavelet packet tree.
Args:
data (torch.Tensor, optional): The input data array of shape ``[time]``,
Expand Down Expand Up @@ -98,13 +100,10 @@ def __init__(
>>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
>>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
>>> wavelet=pywt.Wavelet("db3"), mode="reflect")
>>> np_lst = []
>>> for node in wp.get_level(5):
>>> np_lst.append(wp[node])
>>> np_lst = [wp[node] for node in wp.get_level(5)]
>>> viz = np.stack(np_lst).squeeze()
>>> plt.imshow(np.abs(viz))
>>> plt.show()
"""
self.wavelet = _as_wavelet(wavelet)
self.mode = mode
Expand All @@ -113,30 +112,54 @@ def __init__(
self._matrix_waverec_dict: dict[int, MatrixWaverec] = {}
self.maxlevel: Optional[int] = None
self.axis = axis

self._filter_keys = {"a", "d"}

if data is not None:
self.transform(data, maxlevel)
else:
self.data = {}

def transform(
self, data: torch.Tensor, maxlevel: Optional[int] = None
self,
data: torch.Tensor,
maxlevel: Optional[int] = None,
) -> WaveletPacket:
"""Calculate the 1d wavelet packet transform for the input data.
"""Lazily calculate the 1d wavelet packet transform for the input data.
The packet tree is initialized lazily, i.e. a coefficient is only
calculated as it is retrieved. This allows for partial expansion
of the wavelet packet tree.
The transform function allows reusing the same object.
Args:
data (torch.Tensor): The input data array of shape ``[time]``
or ``[batch_size, time]``.
maxlevel (int, optional): The highest decomposition level to compute.
If None, the maximum level is determined from the input data shape.
Defaults to None.
Returns:
This wavelet packet object (to allow call chaining).
"""
self.data = {}
self.data = {"": data}
if maxlevel is None:
maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len)
self.maxlevel = maxlevel
self._recursive_dwt(data, level=0, path="")
return self

def initialize(self, keys: Iterable[str]) -> None:
"""Initialize the wavelet packet tree partially.
Args:
keys (Iterable[str]): An iterable yielding the keys of the
tree nodes to initialize.
"""
it = (self[key] for key in keys)
# exhaust iterator without storing all values
collections.deque(it, maxlen=0)

def reconstruct(self) -> WaveletPacket:
"""Recursively reconstruct the input starting from the leaf nodes.
Expand All @@ -153,18 +176,29 @@ def reconstruct(self) -> WaveletPacket:
>>> signal = np.random.randn(1, 16)
>>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar",
>>> mode="boundary", maxlevel=2)
>>> ptwp["aa"].data *= 0
>>> # initialize other leaf nodes
>>> ptwp.initialize(["ad", "da", "dd"])
>>> ptwp["aa"] = torch.zeros_like(ptwp["ad"])
>>> ptwp.reconstruct()
>>> print(ptwp[""])
Raises:
KeyError: if any leaf node data is not present.
"""
if self.maxlevel is None:
self.maxlevel = pywt.dwt_max_level(self[""].shape[-1], self.wavelet.dec_len)

for level in reversed(range(self.maxlevel)):
for node in self.get_level(level):
# check if any children is not available
# we need to check manually to avoid lazy init
for child in self._filter_keys:
if node + child not in self:
raise KeyError(f"Key {node + child} not found")

data_a = self[node + "a"]
data_b = self[node + "d"]
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_b])
data_d = self[node + "d"]
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_d])
if level > 0:
if rec.shape[self.axis] != self[node].shape[self.axis]:
assert (
Expand Down Expand Up @@ -242,15 +276,11 @@ def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
else:
return graycode_order

def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
if self.maxlevel is None:
raise AssertionError

self.data[path] = data
if level < self.maxlevel:
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
self._recursive_dwt(res_lo, level + 1, path + "a")
self._recursive_dwt(res_hi, level + 1, path + "d")
def _expand_node(self, path: str) -> None:
data = self[path]
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
self.data[path + "a"] = res_lo
self.data[path + "d"] = res_hi

def __getitem__(self, key: str) -> torch.Tensor:
"""Access the coefficients in the wavelet packets tree.
Expand All @@ -265,7 +295,8 @@ def __getitem__(self, key: str) -> torch.Tensor:
Raises:
ValueError: If the wavelet packet tree is not initialized.
KeyError: If no wavelet coefficients are indexed by the specified key.
KeyError: If no wavelet coefficients are indexed by the specified key
and a lazy initialization fails.
"""
if self.maxlevel is None:
raise ValueError(
Expand All @@ -278,6 +309,20 @@ def __getitem__(self, key: str) -> torch.Tensor:
"cannot be accessed! This wavelet packet tree is initialized with "
f"maximum level {self.maxlevel}."
)
elif key not in self:
if key == "":
raise ValueError(
"The requested root of the packet tree cannot be accessed! "
"The wavelet packet tree is not properly initialized. "
"Run `transform` before accessing tree values."
)
elif key[-1] not in self._filter_keys:
raise ValueError(
f"Invalid key '{key}'. All chars in the key must be of the "
f"set {self._filter_keys}."
)
# calculate data from parent
self._expand_node(key[:-1])
return super().__getitem__(key)


Expand All @@ -300,6 +345,10 @@ def __init__(
) -> None:
"""Create a 2D-Wavelet packet tree.
The packet tree is initialized lazily, i.e. a coefficient is only
calculated as it is retrieved. This allows for partial expansion
of the wavelet packet tree.
Args:
data (torch.tensor, optional): The input data tensor.
For example of shape ``[batch_size, height, width]`` or
Expand All @@ -324,7 +373,6 @@ def __init__(
Only used if `mode` equals 'boundary'. Defaults to 'qr'.
separable (bool): If true, a separable transform is performed,
i.e. each image axis is transformed separately. Defaults to False.
"""
self.wavelet = _as_wavelet(wavelet)
self.mode = mode
Expand All @@ -333,6 +381,7 @@ def __init__(
self.matrix_wavedec2_dict: dict[tuple[int, ...], MatrixWavedec2] = {}
self.matrix_waverec2_dict: dict[tuple[int, ...], MatrixWaverec2] = {}
self.axes = axes
self._filter_keys = {"a", "h", "v", "d"}

self.maxlevel: Optional[int] = None
if data is not None:
Expand All @@ -341,42 +390,70 @@ def __init__(
self.data = {}

def transform(
self, data: torch.Tensor, maxlevel: Optional[int] = None
self,
data: torch.Tensor,
maxlevel: Optional[int] = None,
) -> WaveletPacket2D:
"""Calculate the 2d wavelet packet transform for the input data.
"""Lazily calculate the 2d wavelet packet transform for the input data.
The packet tree is initialized lazily, i.e. a coefficient is only
calculated as it is retrieved. This allows for partial expansion
of the wavelet packet tree.
The transform function allows reusing the same object.
The transform function allows reusing the same object.
Args:
data (torch.tensor): The input data tensor
of shape ``[batch_size, height, width]``.
maxlevel (int, optional): The highest decomposition level to compute.
If None, the maximum level is determined from the input data shape.
Defaults to None.
Returns:
This wavelet packet object (to allow call chaining).
"""
self.data = {}
self.data = {"": data}
if maxlevel is None:
min_transform_size = min(_swap_axes(data, self.axes).shape[-2:])
maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
self.maxlevel = maxlevel

self._recursive_dwt2d(data, level=0, path="")
return self

def initialize(self, keys: Iterable[str]) -> None:
"""Initialize the wavelet packet tree partially.
Args:
keys (Iterable[str]): An iterable yielding the keys of the
tree nodes to initialize.
"""
it = (self[key] for key in keys)
# exhaust iterator without storing all values
collections.deque(it, maxlen=0)

def reconstruct(self) -> WaveletPacket2D:
"""Recursively reconstruct the input starting from the leaf nodes.
Note:
Only changes to leaf node data impact the results,
since changes in all other nodes will be replaced with
a reconstruction from the leaves.
Raises:
KeyError: if any leaf node data is not present.
"""
if self.maxlevel is None:
min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:])
self.maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)

for level in reversed(range(self.maxlevel)):
for node in WaveletPacket2D.get_natural_order(level):
# check if any children is not available
# we need to check manually to avoid lazy init
for child in self._filter_keys:
if node + child not in self:
raise KeyError(f"Key {node + child} not found")

data_a = self[node + "a"]
data_h = self[node + "h"]
data_v = self[node + "v"]
Expand All @@ -402,6 +479,19 @@ def reconstruct(self) -> WaveletPacket2D:
self[node] = rec
return self

def _expand_node(self, path: str) -> None:
data = self[path]
transform_size = _swap_axes(data, self.axes).shape[-2:]
result = self._get_wavedec(transform_size)(data)

# assert for type checking
assert len(result) == 2
result_a, (result_h, result_v, result_d) = result
self.data[path + "a"] = result_a
self.data[path + "h"] = result_h
self.data[path + "v"] = result_v
self.data[path + "d"] = result_d

def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[
[torch.Tensor],
WaveletCoeff2d,
Expand Down Expand Up @@ -483,23 +573,6 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:

return _fsdict_func

def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
if self.maxlevel is None:
raise AssertionError

self.data[path] = data
if level < self.maxlevel:
transform_size = _swap_axes(data, self.axes).shape[-2:]
result = self._get_wavedec(transform_size)(data)

# assert for type checking
assert len(result) == 2
result_a, (result_h, result_v, result_d) = result
self._recursive_dwt2d(result_a, level + 1, path + "a")
self._recursive_dwt2d(result_h, level + 1, path + "h")
self._recursive_dwt2d(result_v, level + 1, path + "v")
self._recursive_dwt2d(result_d, level + 1, path + "d")

def __getitem__(self, key: str) -> torch.Tensor:
"""Access the coefficients in the wavelet packets tree.
Expand All @@ -516,7 +589,8 @@ def __getitem__(self, key: str) -> torch.Tensor:
Raises:
ValueError: If the wavelet packet tree is not initialized.
KeyError: If no wavelet coefficients are indexed by the specified key.
KeyError: If no wavelet coefficients are indexed by the specified key
and a lazy initialization fails.
"""
if self.maxlevel is None:
raise ValueError(
Expand All @@ -529,6 +603,21 @@ def __getitem__(self, key: str) -> torch.Tensor:
"cannot be accessed! This wavelet packet tree is initialized with "
f"maximum level {self.maxlevel}."
)
elif key not in self:
if key == "":
raise ValueError(
"The requested root of the packet tree cannot be accessed! "
"The wavelet packet tree is not properly initialized. "
"Run `transform` before accessing tree values."
)
elif key[-1] not in self._filter_keys:
raise ValueError(
f"Invalid key '{key}'. All chars in the key must be of the "
f"set {self._filter_keys}."
)
# calculate data from parent
self._expand_node(key[:-1])

return super().__getitem__(key)

@overload
Expand Down
Loading

0 comments on commit fa7af3d

Please sign in to comment.