From ef9ba5a6f5b82cb86cdec38354e1772d28931bc8 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:23:02 +0200 Subject: [PATCH 01/21] Change packet logic to use expand func --- src/ptwt/packets.py | 59 +++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 879f5eff..9bad134b 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -132,11 +132,11 @@ def transform( If None, the maximum level is determined from the input data shape. Defaults to None. """ - self.data = {} + self.data = {"": data} if maxlevel is None: maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len) self.maxlevel = maxlevel - self._recursive_dwt(data, level=0, path="") + self._recursive_dwt(path="") return self def reconstruct(self) -> WaveletPacket: @@ -226,15 +226,24 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st else: return graycode_order - def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None: + def _expand_node(self, path: str) -> None: + data = self[path] + res_lo, res_hi = self._get_wavedec(data.shape[-1])(data) + self.data[path + "a"] = res_lo + self.data[path + "d"] = res_hi + + def _recursive_dwt(self, path: str) -> None: if not self.maxlevel: raise AssertionError - self.data[path] = data - if level < self.maxlevel: - res_lo, res_hi = self._get_wavedec(data.shape[-1])(data) - self._recursive_dwt(res_lo, level + 1, path + "a") - self._recursive_dwt(res_hi, level + 1, path + "d") + if len(path) >= self.maxlevel: + # nothing to expand + return + + self._expand_node(path) + + for child in ["a", "d"]: + self._recursive_dwt(path + child) def __getitem__(self, key: str) -> torch.Tensor: """Access the coefficients in the wavelet packets tree. @@ -338,7 +347,7 @@ def transform( If None, the maximum level is determined from the input data shape. Defaults to None. """ - self.data = {} + self.data = {"": data} if maxlevel is None: maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len) self.maxlevel = maxlevel @@ -347,7 +356,7 @@ def transform( # add batch dim to unbatched input data = data.unsqueeze(0) - self._recursive_dwt2d(data, level=0, path="") + self._recursive_dwt2d(path="") return self def reconstruct(self) -> WaveletPacket2D: @@ -386,6 +395,18 @@ def reconstruct(self) -> WaveletPacket2D: self[node] = rec return self + def _expand_node(self, path: str) -> None: + data = self[path] + result = self._get_wavedec(data.shape[-2:])(data) + + # assert for type checking + assert len(result) == 2 + result_a, (result_h, result_v, result_d) = result + self.data[path + "a"] = result_a + self.data[path + "h"] = result_h + self.data[path + "v"] = result_v + self.data[path + "d"] = result_d + def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[ [torch.Tensor], WaveletCoeff2d, @@ -467,21 +488,17 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor: return _fsdict_func - def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None: + def _recursive_dwt2d(self, path: str) -> None: if not self.maxlevel: raise AssertionError - self.data[path] = data - if level < self.maxlevel: - result = self._get_wavedec(data.shape[-2:])(data) + if len(path) >= self.maxlevel: + # nothing to expand + return - # assert for type checking - assert len(result) == 2 - result_a, (result_h, result_v, result_d) = result - self._recursive_dwt2d(result_a, level + 1, path + "a") - self._recursive_dwt2d(result_h, level + 1, path + "h") - self._recursive_dwt2d(result_v, level + 1, path + "v") - self._recursive_dwt2d(result_d, level + 1, path + "d") + self._expand_node(path) + for child in ["a", "h", "v", "d"]: + self._recursive_dwt2d(path + child) def __getitem__(self, key: str) -> torch.Tensor: """Access the coefficients in the wavelet packets tree. From f67b28fd6ab70eaafd57eb6bf88a14bb452e6353 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:25:07 +0200 Subject: [PATCH 02/21] Calculate non-existing coeffs --- src/ptwt/packets.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 9bad134b..88cd82bc 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -271,6 +271,15 @@ def __getitem__(self, key: str) -> torch.Tensor: "cannot be accessed! This wavelet packet tree is initialized with " f"maximum level {self.maxlevel}." ) + elif key not in self: + if key == "": + raise ValueError( + "The requested root of the packet tree cannot be accessed! " + "The wavelet packet tree is not properly initialized. " + "Run `transform` before accessing tree values." + ) + # calculate data from parent + self._expand_node(key[:-1]) return super().__getitem__(key) @@ -529,6 +538,16 @@ def __getitem__(self, key: str) -> torch.Tensor: "cannot be accessed! This wavelet packet tree is initialized with " f"maximum level {self.maxlevel}." ) + elif key not in self: + if key == "": + raise ValueError( + "The requested root of the packet tree cannot be accessed! " + "The wavelet packet tree is not properly initialized. " + "Run `transform` before accessing tree values." + ) + # calculate data from parent + self._expand_node(key[:-1]) + return super().__getitem__(key) @staticmethod From b78434af73cccf09f9347ac22f0bea72dc73818f Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:27:11 +0200 Subject: [PATCH 03/21] Add arg for lazy packet initialization --- src/ptwt/packets.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 88cd82bc..614a9e41 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -61,6 +61,7 @@ def __init__( maxlevel: Optional[int] = None, axis: int = -1, boundary_orthogonalization: OrthogonalizeMethod = "qr", + lazy_init: bool = False, ) -> None: """Create a wavelet packet decomposition object. @@ -116,12 +117,15 @@ def __init__( if len(data.shape) == 1: # add a batch dimension. data = data.unsqueeze(0) - self.transform(data, maxlevel) + self.transform(data, maxlevel, lazy_init=lazy_init) else: self.data = {} def transform( - self, data: torch.Tensor, maxlevel: Optional[int] = None + self, + data: torch.Tensor, + maxlevel: Optional[int] = None, + lazy_init: bool = False, ) -> WaveletPacket: """Calculate the 1d wavelet packet transform for the input data. @@ -136,7 +140,8 @@ def transform( if maxlevel is None: maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len) self.maxlevel = maxlevel - self._recursive_dwt(path="") + if not lazy_init: + self._recursive_dwt(path="") return self def reconstruct(self) -> WaveletPacket: @@ -299,6 +304,7 @@ def __init__( axes: tuple[int, int] = (-2, -1), boundary_orthogonalization: OrthogonalizeMethod = "qr", separable: bool = False, + lazy_init: bool = False, ) -> None: """Create a 2D-Wavelet packet tree. @@ -338,12 +344,15 @@ def __init__( self.maxlevel: Optional[int] = None if data is not None: - self.transform(data, maxlevel) + self.transform(data, maxlevel, lazy_init=lazy_init) else: self.data = {} def transform( - self, data: torch.Tensor, maxlevel: Optional[int] = None + self, + data: torch.Tensor, + maxlevel: Optional[int] = None, + lazy_init: bool = False, ) -> WaveletPacket2D: """Calculate the 2d wavelet packet transform for the input data. @@ -365,7 +374,8 @@ def transform( # add batch dim to unbatched input data = data.unsqueeze(0) - self._recursive_dwt2d(path="") + if not lazy_init: + self._recursive_dwt2d(path="") return self def reconstruct(self) -> WaveletPacket2D: From d98244d3db8a23b11ca6a2a350f6653e54f78ef9 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:32:28 +0200 Subject: [PATCH 04/21] Add arg docstrings --- src/ptwt/packets.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 614a9e41..23dd7137 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -88,6 +88,10 @@ def __init__( to use in the sparse matrix backend, see :data:`ptwt.constants.OrthogonalizeMethod`. Only used if `mode` equals 'boundary'. Defaults to 'qr'. + lazy_init (bool): Value is passed on to :func:`transform`. + If True, the packet tree is initialized lazily. This + allows for partial expansion of the wavelet packet tree. + Defaults to False. Example: >>> import torch, pywt, ptwt @@ -135,6 +139,10 @@ def transform( maxlevel (int, optional): The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. + lazy_init (bool): If True, the packet tree is initialized lazily. + This allows for partial expansion of the wavelet packet tree. + Otherwise, all packet coefficients up to the decomposition level + `maxlevel` are computed. Defaults to False. """ self.data = {"": data} if maxlevel is None: @@ -332,6 +340,10 @@ def __init__( Only used if `mode` equals 'boundary'. Defaults to 'qr'. separable (bool): If true, a separable transform is performed, i.e. each image axis is transformed separately. Defaults to False. + lazy_init (bool): Value is passed on to :func:`transform`. + If True, the packet tree is initialized lazily. This + allows for partial expansion of the wavelet packet tree. + Defaults to False. """ self.wavelet = _as_wavelet(wavelet) @@ -364,6 +376,10 @@ def transform( maxlevel (int, optional): The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. + lazy_init (bool): If True, the packet tree is initialized lazily. + This allows for partial expansion of the wavelet packet tree. + Otherwise, all packet coefficients up to the decomposition level + `maxlevel` are computed. Defaults to False. """ self.data = {"": data} if maxlevel is None: From 777853669d1b5d66b1aea8719fce7eca4b5ec52e Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:33:41 +0200 Subject: [PATCH 05/21] Add return docstr and change import --- src/ptwt/packets.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 23dd7137..1c9e96fe 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -3,10 +3,10 @@ from __future__ import annotations import collections -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial from itertools import product -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union import numpy as np import pywt @@ -108,7 +108,6 @@ def __init__( >>> viz = np.stack(np_lst).squeeze() >>> plt.imshow(np.abs(viz)) >>> plt.show() - """ self.wavelet = _as_wavelet(wavelet) self.mode = mode @@ -143,6 +142,9 @@ def transform( This allows for partial expansion of the wavelet packet tree. Otherwise, all packet coefficients up to the decomposition level `maxlevel` are computed. Defaults to False. + + Returns: + This wavelet packet object (to allow call chaining). """ self.data = {"": data} if maxlevel is None: @@ -344,7 +346,6 @@ def __init__( If True, the packet tree is initialized lazily. This allows for partial expansion of the wavelet packet tree. Defaults to False. - """ self.wavelet = _as_wavelet(wavelet) self.mode = mode @@ -380,6 +381,9 @@ def transform( This allows for partial expansion of the wavelet packet tree. Otherwise, all packet coefficients up to the decomposition level `maxlevel` are computed. Defaults to False. + + Returns: + This wavelet packet object (to allow call chaining). """ self.data = {"": data} if maxlevel is None: From f7930d2bfb9fabf00d7e7c475ca48208f35d177d Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:44:19 +0200 Subject: [PATCH 06/21] Adjust docstr --- src/ptwt/packets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 1c9e96fe..f23c2acc 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -273,7 +273,8 @@ def __getitem__(self, key: str) -> torch.Tensor: Raises: ValueError: If the wavelet packet tree is not initialized. - KeyError: If no wavelet coefficients are indexed by the specified key. + KeyError: If no wavelet coefficients are indexed by the specified key + and a lazy initialization fails. """ if self.maxlevel is None: raise ValueError( @@ -555,7 +556,8 @@ def __getitem__(self, key: str) -> torch.Tensor: Raises: ValueError: If the wavelet packet tree is not initialized. - KeyError: If no wavelet coefficients are indexed by the specified key. + KeyError: If no wavelet coefficients are indexed by the specified key + and a lazy initialization fails. """ if self.maxlevel is None: raise ValueError( From 94bd429eaa810fdcb1cd2097a4084dcaf7d7b5b0 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:44:44 +0200 Subject: [PATCH 07/21] Add key check to reconstruct --- src/ptwt/packets.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index f23c2acc..4ff1eae7 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -179,9 +179,18 @@ def reconstruct(self) -> WaveletPacket: for level in reversed(range(self.maxlevel)): for node in self.get_level(level): + # check if any children is not available + # we need to check manually to avoid lazy init + def _test_key(key: str) -> None: + if key not in self: + raise KeyError(f"Key {key} not found") + + for child in ["a", "d"]: + _test_key(node + child) + data_a = self[node + "a"] - data_b = self[node + "d"] - rec = self._get_waverec(data_a.shape[-1])([data_a, data_b]) + data_d = self[node + "d"] + rec = self._get_waverec(data_a.shape[-1])([data_a, data_d]) if level > 0: if rec.shape[-1] != self[node].shape[-1]: assert ( @@ -414,6 +423,15 @@ def reconstruct(self) -> WaveletPacket2D: for level in reversed(range(self.maxlevel)): for node in WaveletPacket2D.get_natural_order(level): + # check if any children is not available + # we need to check manually to avoid lazy init + def _test_key(key: str) -> None: + if key not in self: + raise KeyError(f"Key {key} not found") + + for child in ["a", "h", "v", "d"]: + _test_key(node + child) + data_a = self[node + "a"] data_h = self[node + "h"] data_v = self[node + "v"] From 3a7c7a6f1b157feb6ac8533746ff726fb5ec8cd2 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:48:30 +0200 Subject: [PATCH 08/21] Example adjust --- src/ptwt/packets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 4ff1eae7..1cd3d551 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -102,9 +102,7 @@ def __init__( >>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear") >>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)), >>> wavelet=pywt.Wavelet("db3"), mode="reflect") - >>> np_lst = [] - >>> for node in wp.get_level(5): - >>> np_lst.append(wp[node]) + >>> np_lst = [wp[node] for node in wp.get_level(5)] >>> viz = np.stack(np_lst).squeeze() >>> plt.imshow(np.abs(viz)) >>> plt.show() From 9b7981300a612ef445ead367cfc8b9944ac1e643 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:59:26 +0200 Subject: [PATCH 09/21] Add lazy_init to packet tests --- tests/test_packets.py | 66 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 11 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index 00cf6c4b..58cd3bed 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -24,6 +24,7 @@ def _compare_trees1( batch_size: int = 1, transform_mode: bool = False, multiple_transforms: bool = False, + lazy_init: bool = False, ) -> None: data = np.random.rand(batch_size, length) wavelet = pywt.Wavelet(wavelet_str) @@ -31,15 +32,19 @@ def _compare_trees1( if transform_mode: twp = WaveletPacket( None, wavelet, mode=ptwt_boundary, maxlevel=max_lev - ).transform(torch.from_numpy(data), maxlevel=max_lev) + ).transform(torch.from_numpy(data), maxlevel=max_lev, lazy_init=lazy_init) else: twp = WaveletPacket( - torch.from_numpy(data), wavelet, mode=ptwt_boundary, maxlevel=max_lev + torch.from_numpy(data), + wavelet, + mode=ptwt_boundary, + maxlevel=max_lev, + lazy_init=lazy_init, ) # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: - twp.transform(torch.from_numpy(data), maxlevel=max_lev) + twp.transform(torch.from_numpy(data), maxlevel=max_lev, lazy_init=lazy_init) nodes = twp.get_level(twp.maxlevel) twp_lst = [] @@ -76,6 +81,7 @@ def _compare_trees2( batch_size: int = 1, transform_mode: bool = False, multiple_transforms: bool = False, + lazy_init: bool = False, ) -> None: face = datasets.face()[:height, :width] face = np.mean(face, axis=-1).astype(np.float64) @@ -104,15 +110,19 @@ def _compare_trees2( if transform_mode: ptwt_wp_tree = WaveletPacket2D( None, wavelet=wavelet, mode=ptwt_boundary - ).transform(pt_data, maxlevel=max_lev) + ).transform(pt_data, maxlevel=max_lev, lazy_init=lazy_init) else: ptwt_wp_tree = WaveletPacket2D( - pt_data, wavelet=wavelet, mode=ptwt_boundary, maxlevel=max_lev + pt_data, + wavelet=wavelet, + mode=ptwt_boundary, + maxlevel=max_lev, + lazy_init=lazy_init, ) # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: - ptwt_wp_tree.transform(pt_data, maxlevel=max_lev) + ptwt_wp_tree.transform(pt_data, maxlevel=max_lev, lazy_init=lazy_init) packets = [] for node in wp_keys: @@ -132,6 +142,7 @@ def _compare_trees2( @pytest.mark.parametrize("batch_size", [2, 1]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) +@pytest.mark.parametrize("lazy_init", [False, True]) def test_2d_packets( max_lev: Optional[int], wavelet_str: str, @@ -139,6 +150,7 @@ def test_2d_packets( batch_size: int, transform_mode: bool, multiple_transforms: bool, + lazy_init: bool, ) -> None: """Ensure pywt and ptwt produce equivalent wavelet 2d packet trees.""" _compare_trees2( @@ -149,6 +161,7 @@ def test_2d_packets( batch_size=batch_size, transform_mode=transform_mode, multiple_transforms=multiple_transforms, + lazy_init=lazy_init, ) @@ -157,11 +170,13 @@ def test_2d_packets( @pytest.mark.parametrize("batch_size", [1, 2]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) +@pytest.mark.parametrize("lazy_init", [False, True]) def test_boundary_matrix_packets2( max_lev: Optional[int], batch_size: int, transform_mode: bool, multiple_transforms: bool, + lazy_init: bool, ) -> None: """Ensure the 2d - sparse matrix haar tree and pywt-tree are the same.""" _compare_trees2( @@ -172,6 +187,7 @@ def test_boundary_matrix_packets2( batch_size=batch_size, transform_mode=transform_mode, multiple_transforms=multiple_transforms, + lazy_init=lazy_init, ) @@ -184,6 +200,7 @@ def test_boundary_matrix_packets2( @pytest.mark.parametrize("batch_size", [2, 1]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) +@pytest.mark.parametrize("lazy_init", [False, True]) def test_1d_packets( max_lev: int, wavelet_str: str, @@ -191,6 +208,7 @@ def test_1d_packets( batch_size: int, transform_mode: bool, multiple_transforms: bool, + lazy_init: bool, ) -> None: """Ensure pywt and ptwt produce equivalent wavelet 1d packet trees.""" _compare_trees1( @@ -201,6 +219,7 @@ def test_1d_packets( batch_size=batch_size, transform_mode=transform_mode, multiple_transforms=multiple_transforms, + lazy_init=lazy_init, ) @@ -208,8 +227,12 @@ def test_1d_packets( @pytest.mark.parametrize("max_lev", [1, 2, 3, 4, None]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) +@pytest.mark.parametrize("lazy_init", [False, True]) def test_boundary_matrix_packets1( - max_lev: Optional[int], transform_mode: bool, multiple_transforms: bool + max_lev: Optional[int], + transform_mode: bool, + multiple_transforms: bool, + lazy_init: bool, ) -> None: """Ensure the 2d - sparse matrix haar tree and pywt-tree are the same.""" _compare_trees1( @@ -219,6 +242,7 @@ def test_boundary_matrix_packets1( "boundary", transform_mode=transform_mode, multiple_transforms=multiple_transforms, + lazy_init=lazy_init, ) @@ -314,15 +338,26 @@ def test_access_errors_2d() -> None: @pytest.mark.parametrize("shape", [[1, 64, 63], [3, 64, 64], [1, 128]]) @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) @pytest.mark.parametrize("axis", (1, -1)) +@pytest.mark.parametrize("lazy_init", [True, False]) def test_inverse_packet_1d( - level: int, base_key: str, shape: list[int], wavelet: str, axis: int + level: int, + base_key: str, + shape: list[int], + wavelet: str, + axis: int, + lazy_init: bool, ) -> None: """Test the 1d reconstruction code.""" signal = np.random.randn(*shape) mode = "reflect" wp = pywt.WaveletPacket(signal, wavelet, mode=mode, maxlevel=level, axis=axis) ptwp = WaveletPacket( - torch.from_numpy(signal), wavelet, mode=mode, maxlevel=level, axis=axis + torch.from_numpy(signal), + wavelet, + mode=mode, + maxlevel=level, + axis=axis, + lazy_init=lazy_init, ) wp[base_key * level].data *= 0 ptwp[base_key * level].data *= 0 @@ -337,19 +372,26 @@ def test_inverse_packet_1d( @pytest.mark.parametrize("size", [(32, 32, 32), (32, 32, 31, 64)]) @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) +@pytest.mark.parametrize("lazy_init", [True, False]) def test_inverse_packet_2d( level: int, base_key: str, size: tuple[int, ...], wavelet: str, axes: tuple[int, int], + lazy_init: bool, ) -> None: """Test the 2d reconstruction code.""" signal = np.random.randn(*size) mode = "reflect" wp = pywt.WaveletPacket2D(signal, wavelet, mode=mode, maxlevel=level, axes=axes) ptwp = WaveletPacket2D( - torch.from_numpy(signal), wavelet, mode=mode, maxlevel=level, axes=axes + torch.from_numpy(signal), + wavelet, + mode=mode, + maxlevel=level, + axes=axes, + lazy_init=lazy_init, ) wp[base_key * level].data *= 0 ptwp[base_key * level].data *= 0 @@ -390,7 +432,8 @@ def test_inverse_boundary_packet_2d() -> None: @pytest.mark.slow @pytest.mark.parametrize("axes", ((-2, -1), (1, 2), (2, 1))) -def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None: +@pytest.mark.parametrize("lazy_init", [True, False]) +def test_separable_conv_packets_2d(axes: tuple[int, int], lazy_init: bool) -> None: """Ensure the 2d separable conv code is ok.""" wavelet = "db2" signal = np.random.randn(1, 32, 32, 32) @@ -401,6 +444,7 @@ def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None: maxlevel=2, axes=axes, separable=True, + lazy_init=lazy_init, ) ptwp.reconstruct() assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32]) From d88cea2119d4067279994d6a97361dff4ca2f6ce Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 13:59:45 +0200 Subject: [PATCH 10/21] Avoid calling .data on a tensor --- tests/test_packets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index 58cd3bed..493a66d1 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -360,7 +360,7 @@ def test_inverse_packet_1d( lazy_init=lazy_init, ) wp[base_key * level].data *= 0 - ptwp[base_key * level].data *= 0 + ptwp[base_key * level] *= 0 wp.reconstruct(update=True) ptwp.reconstruct() assert np.allclose(wp[""].data, ptwp[""].numpy()[..., : shape[-2], : shape[-1]]) @@ -394,7 +394,7 @@ def test_inverse_packet_2d( lazy_init=lazy_init, ) wp[base_key * level].data *= 0 - ptwp[base_key * level].data *= 0 + ptwp[base_key * level] *= 0 wp.reconstruct(update=True) ptwp.reconstruct() assert np.allclose(wp[""].data, ptwp[""].numpy()[: size[0], : size[1], : size[2]]) From c8dd24753c6f061ed375285be26e7bab6555807d Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 14:27:28 +0200 Subject: [PATCH 11/21] Add unit test for partial expansion --- tests/test_packets.py | 137 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/tests/test_packets.py b/tests/test_packets.py index 493a66d1..758c69b2 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -273,6 +273,143 @@ def test_freq_order(level: int, wavelet_str: str, pywt_boundary: str) -> None: assert order_el.path == "".join(tree_el) +partial_keys_1d = ["aaaa", "aaad", "aad", "ad", "da", "dd"] + +partial_keys_2d = [ + "aaaa", + "aaad", + "aaah", + "aaav", + "aad", + "aah", + "aava", + "aavd", + "aavh", + "aavv", + "ad", + "ah", + "ava", + "avd", + "avh", + "avv", + "d", + "h", + "vaa", + "vad", + "vah", + "vav", + "vd", + "vh", + "vv", +] + + +@pytest.mark.parametrize("wavelet_str", ["haar", "db4"]) +@pytest.mark.parametrize("boundary", ["zero", "reflect", "constant", "boundary"]) +def test_partial_expansion_1d(wavelet_str: str, boundary: str) -> None: + """Test lazy init in 1d.""" + max_lev = 4 + shape = 128 + test_signal = torch.randn(shape) + + lazy_init_packet = WaveletPacket( + test_signal, + wavelet_str, + mode=boundary, + maxlevel=max_lev, + lazy_init=True, + ) + + # Full expansion of the wavelet packet tree + full_keys = lazy_init_packet.get_level(max_lev) + + with pytest.raises(AssertionError): + assert all(key in lazy_init_packet for key in full_keys) + + with pytest.raises(AssertionError): + assert all(key in lazy_init_packet for key in partial_keys_1d) + + # init on partial keys + [lazy_init_packet[key] for key in partial_keys_1d] + + with pytest.raises(AssertionError): + assert all(key in lazy_init_packet for key in full_keys) + + assert all(key in lazy_init_packet for key in partial_keys_1d) + + eager_init_packet = WaveletPacket( + test_signal, + wavelet_str, + mode=boundary, + maxlevel=max_lev, + lazy_init=False, + ) + + assert all(key in eager_init_packet for key in full_keys) + + diffs = [ + ((lazy_init_packet[key] - eager_init_packet[key]) ** 2).sum() + for key in lazy_init_packet.keys() + ] + delta = torch.sum(torch.stack(diffs)) + + assert torch.isclose(delta, torch.tensor(0.0)) + + +@pytest.mark.parametrize("wavelet_str", ["haar", "db4"]) +@pytest.mark.parametrize("boundary", ["zero", "reflect", "constant", "boundary"]) +def test_partial_expansion_2d(wavelet_str: str, boundary: str) -> None: + """Test lazy init in 2d.""" + max_lev = 4 + shape = (128, 128) + test_signal = torch.randn(shape) + + # Full expansion of the wavelet packet tree + full_keys = WaveletPacket2D.get_natural_order(max_lev) + + lazy_init_packet = WaveletPacket2D( + test_signal, + wavelet_str, + mode=boundary, + maxlevel=max_lev, + lazy_init=True, + separable=True, + ) + + with pytest.raises(AssertionError): + assert all(key in lazy_init_packet for key in full_keys) + + with pytest.raises(AssertionError): + assert all(key in lazy_init_packet for key in partial_keys_2d) + + # init on partial keys + [lazy_init_packet[key] for key in partial_keys_2d] + + with pytest.raises(AssertionError): + assert all(key in lazy_init_packet for key in full_keys) + + assert all(key in lazy_init_packet for key in partial_keys_2d) + + eager_init_packet = WaveletPacket2D( + test_signal, + wavelet_str, + mode=boundary, + maxlevel=max_lev, + lazy_init=False, + separable=True, + ) + + assert all(key in eager_init_packet for key in full_keys) + + diffs = [ + ((lazy_init_packet[key] - eager_init_packet[key]) ** 2).sum() + for key in lazy_init_packet.keys() + ] + delta = torch.sum(torch.stack(diffs)) + + assert torch.isclose(delta, torch.tensor(0.0)) + + def test_packet_harbo_lvl3() -> None: """From Jensen, La Cour-Harbo, Rippels in Mathematics, Chapter 8 (page 89).""" data = np.array([56.0, 40.0, 8.0, 24.0, 48.0, 48.0, 40.0, 16.0]) From 20810367fc969239ba3a2088c738a6249bf7190a Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 14:43:20 +0200 Subject: [PATCH 12/21] Fix reconstruct tests for lazy init --- tests/test_packets.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_packets.py b/tests/test_packets.py index 758c69b2..c5f2a691 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -496,6 +496,13 @@ def test_inverse_packet_1d( axis=axis, lazy_init=lazy_init, ) + if lazy_init: + with pytest.raises(KeyError): + ptwp.reconstruct() + + # lazy init + [ptwp[key] for key in ptwp.get_level(level)] + wp[base_key * level].data *= 0 ptwp[base_key * level] *= 0 wp.reconstruct(update=True) @@ -531,6 +538,13 @@ def test_inverse_packet_2d( lazy_init=lazy_init, ) wp[base_key * level].data *= 0 + if lazy_init: + with pytest.raises(KeyError): + ptwp.reconstruct() + + # lazy init + [ptwp[key] for key in ptwp.get_natural_order(level)] + ptwp[base_key * level] *= 0 wp.reconstruct(update=True) ptwp.reconstruct() @@ -583,5 +597,11 @@ def test_separable_conv_packets_2d(axes: tuple[int, int], lazy_init: bool) -> No separable=True, lazy_init=lazy_init, ) + if lazy_init: + with pytest.raises(KeyError): + ptwp.reconstruct() + + # lazy init + [ptwp[key] for key in ptwp.get_natural_order(2)] ptwp.reconstruct() assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32]) From 8124814c5ead3dc43121fce1e3323d6ecf3aa2a1 Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Wed, 26 Jun 2024 18:00:19 +0200 Subject: [PATCH 13/21] remove argument. --- src/ptwt/packets.py | 42 +-------------------- tests/test_packets.py | 86 ++++++++++++++----------------------------- 2 files changed, 30 insertions(+), 98 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 0ba943ab..92743968 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -62,7 +62,6 @@ def __init__( maxlevel: Optional[int] = None, axis: int = -1, boundary_orthogonalization: OrthogonalizeMethod = "qr", - lazy_init: bool = False, ) -> None: """Create a wavelet packet decomposition object. @@ -89,10 +88,6 @@ def __init__( to use in the sparse matrix backend, see :data:`ptwt.constants.OrthogonalizeMethod`. Only used if `mode` equals 'boundary'. Defaults to 'qr'. - lazy_init (bool): Value is passed on to :func:`transform`. - If True, the packet tree is initialized lazily. This - allows for partial expansion of the wavelet packet tree. - Defaults to False. Example: >>> import torch, pywt, ptwt @@ -116,7 +111,7 @@ def __init__( self.maxlevel: Optional[int] = None self.axis = axis if data is not None: - self.transform(data, maxlevel, lazy_init=lazy_init) + self.transform(data, maxlevel) else: self.data = {} @@ -124,7 +119,6 @@ def transform( self, data: torch.Tensor, maxlevel: Optional[int] = None, - lazy_init: bool = False, ) -> WaveletPacket: """Calculate the 1d wavelet packet transform for the input data. @@ -134,10 +128,6 @@ def transform( maxlevel (int, optional): The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. - lazy_init (bool): If True, the packet tree is initialized lazily. - This allows for partial expansion of the wavelet packet tree. - Otherwise, all packet coefficients up to the decomposition level - `maxlevel` are computed. Defaults to False. Returns: This wavelet packet object (to allow call chaining). @@ -146,8 +136,6 @@ def transform( if maxlevel is None: maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len) self.maxlevel = maxlevel - if not lazy_init: - self._recursive_dwt(path="") return self def reconstruct(self) -> WaveletPacket: @@ -270,19 +258,6 @@ def _expand_node(self, path: str) -> None: self.data[path + "a"] = res_lo self.data[path + "d"] = res_hi - def _recursive_dwt(self, path: str) -> None: - if self.maxlevel is None: - raise AssertionError - - if len(path) >= self.maxlevel: - # nothing to expand - return - - self._expand_node(path) - - for child in ["a", "d"]: - self._recursive_dwt(path + child) - def __getitem__(self, key: str) -> torch.Tensor: """Access the coefficients in the wavelet packets tree. @@ -338,7 +313,6 @@ def __init__( axes: tuple[int, int] = (-2, -1), boundary_orthogonalization: OrthogonalizeMethod = "qr", separable: bool = False, - lazy_init: bool = False, ) -> None: """Create a 2D-Wavelet packet tree. @@ -366,10 +340,6 @@ def __init__( Only used if `mode` equals 'boundary'. Defaults to 'qr'. separable (bool): If true, a separable transform is performed, i.e. each image axis is transformed separately. Defaults to False. - lazy_init (bool): Value is passed on to :func:`transform`. - If True, the packet tree is initialized lazily. This - allows for partial expansion of the wavelet packet tree. - Defaults to False. """ self.wavelet = _as_wavelet(wavelet) self.mode = mode @@ -381,7 +351,7 @@ def __init__( self.maxlevel: Optional[int] = None if data is not None: - self.transform(data, maxlevel, lazy_init=lazy_init) + self.transform(data, maxlevel) else: self.data = {} @@ -389,7 +359,6 @@ def transform( self, data: torch.Tensor, maxlevel: Optional[int] = None, - lazy_init: bool = False, ) -> WaveletPacket2D: """Calculate the 2d wavelet packet transform for the input data. @@ -401,10 +370,6 @@ def transform( maxlevel (int, optional): The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. - lazy_init (bool): If True, the packet tree is initialized lazily. - This allows for partial expansion of the wavelet packet tree. - Otherwise, all packet coefficients up to the decomposition level - `maxlevel` are computed. Defaults to False. Returns: This wavelet packet object (to allow call chaining). @@ -414,9 +379,6 @@ def transform( min_transform_size = min(_swap_axes(data, self.axes).shape[-2:]) maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len) self.maxlevel = maxlevel - - if not lazy_init: - self._recursive_dwt2d(path="") return self def reconstruct(self) -> WaveletPacket2D: diff --git a/tests/test_packets.py b/tests/test_packets.py index 564d0a8f..2ddd317b 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -25,14 +25,13 @@ def _compare_trees1( transform_mode: bool = False, multiple_transforms: bool = False, axis: int = -1, - lazy_init: bool = False, ) -> None: data = np.random.rand(batch_size, length) data = data.swapaxes(axis, -1) if transform_mode: twp = WaveletPacket(None, wavelet_str, mode=ptwt_boundary, axis=axis).transform( - torch.from_numpy(data), maxlevel=max_lev, lazy_init=lazy_init + torch.from_numpy(data), maxlevel=max_lev ) else: twp = WaveletPacket( @@ -41,12 +40,11 @@ def _compare_trees1( mode=ptwt_boundary, maxlevel=max_lev, axis=axis, - lazy_init=lazy_init, ) # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: - twp.transform(torch.from_numpy(data), maxlevel=max_lev, lazy_init=lazy_init) + twp.transform(torch.from_numpy(data), maxlevel=max_lev) torch_res = torch.cat([twp[node] for node in twp.get_level(twp.maxlevel)], axis) @@ -76,7 +74,6 @@ def _compare_trees2( transform_mode: bool = False, multiple_transforms: bool = False, axes: tuple[int, int] = (-2, -1), - lazy_init: bool = False, ) -> None: face = datasets.face()[:height, :width].astype(np.float64).mean(-1) data = torch.stack([torch.from_numpy(face)] * batch_size, 0) @@ -103,20 +100,19 @@ def _compare_trees2( if transform_mode: ptwt_wp_tree = WaveletPacket2D( None, wavelet=wavelet_str, mode=ptwt_boundary, axes=axes - ).transform(data, maxlevel=max_lev, lazy_init=lazy_init) + ).transform(data, maxlevel=max_lev) else: ptwt_wp_tree = WaveletPacket2D( data, wavelet=wavelet_str, mode=ptwt_boundary, maxlevel=max_lev, - axes=axes, - lazy_init=lazy_init, + axes=axes ) # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: - ptwt_wp_tree.transform(data, maxlevel=max_lev, lazy_init=lazy_init) + ptwt_wp_tree.transform(data, maxlevel=max_lev) packets_pt = torch.stack( [ @@ -140,7 +136,6 @@ def _compare_trees2( @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) -@pytest.mark.parametrize("lazy_init", [False, True]) def test_2d_packets( max_lev: Optional[int], wavelet_str: str, @@ -148,8 +143,7 @@ def test_2d_packets( batch_size: int, transform_mode: bool, multiple_transforms: bool, - axes: tuple[int, int], - lazy_init: bool, + axes: tuple[int, int] ) -> None: """Ensure pywt and ptwt produce equivalent wavelet 2d packet trees.""" _compare_trees2( @@ -160,8 +154,7 @@ def test_2d_packets( batch_size=batch_size, transform_mode=transform_mode, multiple_transforms=multiple_transforms, - axes=axes, - lazy_init=lazy_init, + axes=axes ) @@ -171,14 +164,12 @@ def test_2d_packets( @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) -@pytest.mark.parametrize("lazy_init", [False, True]) def test_boundary_matrix_packets2( max_lev: Optional[int], batch_size: int, transform_mode: bool, multiple_transforms: bool, axes: tuple[int, int], - lazy_init: bool, ) -> None: """Ensure the 2d - sparse matrix haar tree and pywt-tree are the same.""" _compare_trees2( @@ -189,8 +180,7 @@ def test_boundary_matrix_packets2( batch_size=batch_size, transform_mode=transform_mode, multiple_transforms=multiple_transforms, - axes=axes, - lazy_init=lazy_init, + axes=axes ) @@ -204,7 +194,6 @@ def test_boundary_matrix_packets2( @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) @pytest.mark.parametrize("axis", [0, -1]) -@pytest.mark.parametrize("lazy_init", [False, True]) def test_1d_packets( max_lev: int, wavelet_str: str, @@ -212,8 +201,7 @@ def test_1d_packets( batch_size: int, transform_mode: bool, multiple_transforms: bool, - axis: int, - lazy_init: bool, + axis: int ) -> None: """Ensure pywt and ptwt produce equivalent wavelet 1d packet trees.""" _compare_trees1( @@ -225,7 +213,6 @@ def test_1d_packets( transform_mode=transform_mode, multiple_transforms=multiple_transforms, axis=axis, - lazy_init=lazy_init, ) @@ -233,12 +220,10 @@ def test_1d_packets( @pytest.mark.parametrize("max_lev", [1, 2, 3, 4, None]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) -@pytest.mark.parametrize("lazy_init", [False, True]) def test_boundary_matrix_packets1( max_lev: Optional[int], transform_mode: bool, - multiple_transforms: bool, - lazy_init: bool, + multiple_transforms: bool ) -> None: """Ensure the 2d - sparse matrix haar tree and pywt-tree are the same.""" _compare_trees1( @@ -247,8 +232,7 @@ def test_boundary_matrix_packets1( "zero", "boundary", transform_mode=transform_mode, - multiple_transforms=multiple_transforms, - lazy_init=lazy_init, + multiple_transforms=multiple_transforms ) @@ -356,8 +340,7 @@ def test_partial_expansion_1d(wavelet_str: str, boundary: str) -> None: test_signal, wavelet_str, mode=boundary, - maxlevel=max_lev, - lazy_init=True, + maxlevel=max_lev ) # Full expansion of the wavelet packet tree @@ -382,7 +365,6 @@ def test_partial_expansion_1d(wavelet_str: str, boundary: str) -> None: wavelet_str, mode=boundary, maxlevel=max_lev, - lazy_init=False, ) assert all(key in eager_init_packet for key in full_keys) @@ -412,7 +394,6 @@ def test_partial_expansion_2d(wavelet_str: str, boundary: str) -> None: wavelet_str, mode=boundary, maxlevel=max_lev, - lazy_init=True, separable=True, ) @@ -435,7 +416,6 @@ def test_partial_expansion_2d(wavelet_str: str, boundary: str) -> None: wavelet_str, mode=boundary, maxlevel=max_lev, - lazy_init=False, separable=True, ) @@ -508,14 +488,12 @@ def test_access_errors_2d() -> None: @pytest.mark.parametrize("shape", [[1, 64, 63], [3, 64, 64], [1, 128]]) @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) @pytest.mark.parametrize("axis", (1, -1)) -@pytest.mark.parametrize("lazy_init", [True, False]) def test_inverse_packet_1d( level: int, base_key: str, shape: list[int], wavelet: str, - axis: int, - lazy_init: bool, + axis: int ) -> None: """Test the 1d reconstruction code.""" signal = np.random.randn(*shape) @@ -526,15 +504,13 @@ def test_inverse_packet_1d( wavelet, mode=mode, maxlevel=level, - axis=axis, - lazy_init=lazy_init, + axis=axis ) - if lazy_init: - with pytest.raises(KeyError): - ptwp.reconstruct() - # lazy init - [ptwp[key] for key in ptwp.get_level(level)] + with pytest.raises(KeyError): + ptwp.reconstruct() + # lazy init + [ptwp[key] for key in ptwp.get_level(level)] wp[base_key * level].data *= 0 ptwp[base_key * level] *= 0 @@ -549,14 +525,12 @@ def test_inverse_packet_1d( @pytest.mark.parametrize("size", [(32, 32, 32), (32, 32, 31, 64)]) @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) -@pytest.mark.parametrize("lazy_init", [True, False]) def test_inverse_packet_2d( level: int, base_key: str, size: tuple[int, ...], wavelet: str, axes: tuple[int, int], - lazy_init: bool, ) -> None: """Test the 2d reconstruction code.""" signal = np.random.randn(*size) @@ -568,15 +542,14 @@ def test_inverse_packet_2d( mode=mode, maxlevel=level, axes=axes, - lazy_init=lazy_init, ) wp[base_key * level].data *= 0 - if lazy_init: - with pytest.raises(KeyError): - ptwp.reconstruct() - # lazy init - [ptwp[key] for key in ptwp.get_natural_order(level)] + with pytest.raises(KeyError): + ptwp.reconstruct() + + # lazy init + [ptwp[key] for key in ptwp.get_natural_order(level)] ptwp[base_key * level] *= 0 wp.reconstruct(update=True) @@ -616,8 +589,7 @@ def test_inverse_boundary_packet_2d() -> None: @pytest.mark.slow @pytest.mark.parametrize("axes", ((-2, -1), (1, 2), (2, 1))) -@pytest.mark.parametrize("lazy_init", [True, False]) -def test_separable_conv_packets_2d(axes: tuple[int, int], lazy_init: bool) -> None: +def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None: """Ensure the 2d separable conv code is ok.""" wavelet = "db2" signal = np.random.randn(1, 32, 32, 32) @@ -628,13 +600,11 @@ def test_separable_conv_packets_2d(axes: tuple[int, int], lazy_init: bool) -> No maxlevel=2, axes=axes, separable=True, - lazy_init=lazy_init, ) - if lazy_init: - with pytest.raises(KeyError): - ptwp.reconstruct() + with pytest.raises(KeyError): + ptwp.reconstruct() - # lazy init - [ptwp[key] for key in ptwp.get_natural_order(2)] + # lazy init + [ptwp[key] for key in ptwp.get_natural_order(2)] ptwp.reconstruct() assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32]) From 6b62e0ff74a0c4183df568e209402768caf54f0a Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Wed, 26 Jun 2024 17:51:44 +0200 Subject: [PATCH 14/21] Make lazy init the default packet behavior --- src/ptwt/packets.py | 89 +++++++++++++----------------------- tests/test_packets.py | 102 ++++++++++-------------------------------- 2 files changed, 55 insertions(+), 136 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 0ba943ab..c01825f5 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -3,7 +3,7 @@ from __future__ import annotations import collections -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial from itertools import product from typing import TYPE_CHECKING, Literal, Optional, Union, overload @@ -62,11 +62,12 @@ def __init__( maxlevel: Optional[int] = None, axis: int = -1, boundary_orthogonalization: OrthogonalizeMethod = "qr", - lazy_init: bool = False, ) -> None: """Create a wavelet packet decomposition object. - The decompositions will rely on padded fast wavelet transforms. + The packet tree is initialized lazily, i.e. a coefficient is only + calculated as it is retrieved. This allows for partial expansion + of the wavelet packet tree. Args: data (torch.Tensor, optional): The input data array of shape ``[time]``, @@ -89,10 +90,6 @@ def __init__( to use in the sparse matrix backend, see :data:`ptwt.constants.OrthogonalizeMethod`. Only used if `mode` equals 'boundary'. Defaults to 'qr'. - lazy_init (bool): Value is passed on to :func:`transform`. - If True, the packet tree is initialized lazily. This - allows for partial expansion of the wavelet packet tree. - Defaults to False. Example: >>> import torch, pywt, ptwt @@ -116,7 +113,7 @@ def __init__( self.maxlevel: Optional[int] = None self.axis = axis if data is not None: - self.transform(data, maxlevel, lazy_init=lazy_init) + self.transform(data, maxlevel) else: self.data = {} @@ -124,9 +121,14 @@ def transform( self, data: torch.Tensor, maxlevel: Optional[int] = None, - lazy_init: bool = False, ) -> WaveletPacket: - """Calculate the 1d wavelet packet transform for the input data. + """Lazily calculate the 1d wavelet packet transform for the input data. + + The packet tree is initialized lazily, i.e. a coefficient is only + calculated as it is retrieved. This allows for partial expansion + of the wavelet packet tree. + + The transform function allows reusing the same object. Args: data (torch.Tensor): The input data array of shape ``[time]`` @@ -134,10 +136,6 @@ def transform( maxlevel (int, optional): The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. - lazy_init (bool): If True, the packet tree is initialized lazily. - This allows for partial expansion of the wavelet packet tree. - Otherwise, all packet coefficients up to the decomposition level - `maxlevel` are computed. Defaults to False. Returns: This wavelet packet object (to allow call chaining). @@ -146,8 +144,6 @@ def transform( if maxlevel is None: maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len) self.maxlevel = maxlevel - if not lazy_init: - self._recursive_dwt(path="") return self def reconstruct(self) -> WaveletPacket: @@ -166,9 +162,14 @@ def reconstruct(self) -> WaveletPacket: >>> signal = np.random.randn(1, 16) >>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar", >>> mode="boundary", maxlevel=2) - >>> ptwp["aa"].data *= 0 + >>> # initialize other leaf nodes + >>> ptwp.initialize(["ad", "da", "dd"]) + >>> ptwp["aa"] = torch.zeros_like(ptwp["ad"]) >>> ptwp.reconstruct() >>> print(ptwp[""]) + + Raises: + KeyError: if any leaf node data is not present. """ if self.maxlevel is None: self.maxlevel = pywt.dwt_max_level(self[""].shape[-1], self.wavelet.dec_len) @@ -270,19 +271,6 @@ def _expand_node(self, path: str) -> None: self.data[path + "a"] = res_lo self.data[path + "d"] = res_hi - def _recursive_dwt(self, path: str) -> None: - if self.maxlevel is None: - raise AssertionError - - if len(path) >= self.maxlevel: - # nothing to expand - return - - self._expand_node(path) - - for child in ["a", "d"]: - self._recursive_dwt(path + child) - def __getitem__(self, key: str) -> torch.Tensor: """Access the coefficients in the wavelet packets tree. @@ -338,10 +326,13 @@ def __init__( axes: tuple[int, int] = (-2, -1), boundary_orthogonalization: OrthogonalizeMethod = "qr", separable: bool = False, - lazy_init: bool = False, ) -> None: """Create a 2D-Wavelet packet tree. + The packet tree is initialized lazily, i.e. a coefficient is only + calculated as it is retrieved. This allows for partial expansion + of the wavelet packet tree. + Args: data (torch.tensor, optional): The input data tensor. For example of shape ``[batch_size, height, width]`` or @@ -366,10 +357,6 @@ def __init__( Only used if `mode` equals 'boundary'. Defaults to 'qr'. separable (bool): If true, a separable transform is performed, i.e. each image axis is transformed separately. Defaults to False. - lazy_init (bool): Value is passed on to :func:`transform`. - If True, the packet tree is initialized lazily. This - allows for partial expansion of the wavelet packet tree. - Defaults to False. """ self.wavelet = _as_wavelet(wavelet) self.mode = mode @@ -381,7 +368,7 @@ def __init__( self.maxlevel: Optional[int] = None if data is not None: - self.transform(data, maxlevel, lazy_init=lazy_init) + self.transform(data, maxlevel) else: self.data = {} @@ -389,11 +376,14 @@ def transform( self, data: torch.Tensor, maxlevel: Optional[int] = None, - lazy_init: bool = False, ) -> WaveletPacket2D: - """Calculate the 2d wavelet packet transform for the input data. + """Lazily calculate the 2d wavelet packet transform for the input data. + + The packet tree is initialized lazily, i.e. a coefficient is only + calculated as it is retrieved. This allows for partial expansion + of the wavelet packet tree. - The transform function allows reusing the same object. + The transform function allows reusing the same object. Args: data (torch.tensor): The input data tensor @@ -401,10 +391,6 @@ def transform( maxlevel (int, optional): The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None. - lazy_init (bool): If True, the packet tree is initialized lazily. - This allows for partial expansion of the wavelet packet tree. - Otherwise, all packet coefficients up to the decomposition level - `maxlevel` are computed. Defaults to False. Returns: This wavelet packet object (to allow call chaining). @@ -415,8 +401,6 @@ def transform( maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len) self.maxlevel = maxlevel - if not lazy_init: - self._recursive_dwt2d(path="") return self def reconstruct(self) -> WaveletPacket2D: @@ -426,6 +410,9 @@ def reconstruct(self) -> WaveletPacket2D: Only changes to leaf node data impact the results, since changes in all other nodes will be replaced with a reconstruction from the leaves. + + Raises: + KeyError: if any leaf node data is not present. """ if self.maxlevel is None: min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:]) @@ -561,18 +548,6 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor: return _fsdict_func - def _recursive_dwt2d(self, path: str) -> None: - if self.maxlevel is None: - raise AssertionError - - if len(path) >= self.maxlevel: - # nothing to expand - return - - self._expand_node(path) - for child in ["a", "h", "v", "d"]: - self._recursive_dwt2d(path + child) - def __getitem__(self, key: str) -> torch.Tensor: """Access the coefficients in the wavelet packets tree. diff --git a/tests/test_packets.py b/tests/test_packets.py index 564d0a8f..b41e50d4 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -25,14 +25,13 @@ def _compare_trees1( transform_mode: bool = False, multiple_transforms: bool = False, axis: int = -1, - lazy_init: bool = False, ) -> None: data = np.random.rand(batch_size, length) data = data.swapaxes(axis, -1) if transform_mode: twp = WaveletPacket(None, wavelet_str, mode=ptwt_boundary, axis=axis).transform( - torch.from_numpy(data), maxlevel=max_lev, lazy_init=lazy_init + torch.from_numpy(data), maxlevel=max_lev ) else: twp = WaveletPacket( @@ -41,12 +40,11 @@ def _compare_trees1( mode=ptwt_boundary, maxlevel=max_lev, axis=axis, - lazy_init=lazy_init, ) # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: - twp.transform(torch.from_numpy(data), maxlevel=max_lev, lazy_init=lazy_init) + twp.transform(torch.from_numpy(data), maxlevel=max_lev) torch_res = torch.cat([twp[node] for node in twp.get_level(twp.maxlevel)], axis) @@ -76,7 +74,6 @@ def _compare_trees2( transform_mode: bool = False, multiple_transforms: bool = False, axes: tuple[int, int] = (-2, -1), - lazy_init: bool = False, ) -> None: face = datasets.face()[:height, :width].astype(np.float64).mean(-1) data = torch.stack([torch.from_numpy(face)] * batch_size, 0) @@ -103,7 +100,7 @@ def _compare_trees2( if transform_mode: ptwt_wp_tree = WaveletPacket2D( None, wavelet=wavelet_str, mode=ptwt_boundary, axes=axes - ).transform(data, maxlevel=max_lev, lazy_init=lazy_init) + ).transform(data, maxlevel=max_lev) else: ptwt_wp_tree = WaveletPacket2D( data, @@ -111,12 +108,11 @@ def _compare_trees2( mode=ptwt_boundary, maxlevel=max_lev, axes=axes, - lazy_init=lazy_init, ) # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: - ptwt_wp_tree.transform(data, maxlevel=max_lev, lazy_init=lazy_init) + ptwt_wp_tree.transform(data, maxlevel=max_lev) packets_pt = torch.stack( [ @@ -140,7 +136,6 @@ def _compare_trees2( @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) -@pytest.mark.parametrize("lazy_init", [False, True]) def test_2d_packets( max_lev: Optional[int], wavelet_str: str, @@ -149,7 +144,6 @@ def test_2d_packets( transform_mode: bool, multiple_transforms: bool, axes: tuple[int, int], - lazy_init: bool, ) -> None: """Ensure pywt and ptwt produce equivalent wavelet 2d packet trees.""" _compare_trees2( @@ -161,7 +155,6 @@ def test_2d_packets( transform_mode=transform_mode, multiple_transforms=multiple_transforms, axes=axes, - lazy_init=lazy_init, ) @@ -171,14 +164,12 @@ def test_2d_packets( @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) -@pytest.mark.parametrize("lazy_init", [False, True]) def test_boundary_matrix_packets2( max_lev: Optional[int], batch_size: int, transform_mode: bool, multiple_transforms: bool, axes: tuple[int, int], - lazy_init: bool, ) -> None: """Ensure the 2d - sparse matrix haar tree and pywt-tree are the same.""" _compare_trees2( @@ -190,7 +181,6 @@ def test_boundary_matrix_packets2( transform_mode=transform_mode, multiple_transforms=multiple_transforms, axes=axes, - lazy_init=lazy_init, ) @@ -204,7 +194,6 @@ def test_boundary_matrix_packets2( @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) @pytest.mark.parametrize("axis", [0, -1]) -@pytest.mark.parametrize("lazy_init", [False, True]) def test_1d_packets( max_lev: int, wavelet_str: str, @@ -213,7 +202,6 @@ def test_1d_packets( transform_mode: bool, multiple_transforms: bool, axis: int, - lazy_init: bool, ) -> None: """Ensure pywt and ptwt produce equivalent wavelet 1d packet trees.""" _compare_trees1( @@ -225,7 +213,6 @@ def test_1d_packets( transform_mode=transform_mode, multiple_transforms=multiple_transforms, axis=axis, - lazy_init=lazy_init, ) @@ -233,12 +220,10 @@ def test_1d_packets( @pytest.mark.parametrize("max_lev", [1, 2, 3, 4, None]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) -@pytest.mark.parametrize("lazy_init", [False, True]) def test_boundary_matrix_packets1( max_lev: Optional[int], transform_mode: bool, multiple_transforms: bool, - lazy_init: bool, ) -> None: """Ensure the 2d - sparse matrix haar tree and pywt-tree are the same.""" _compare_trees1( @@ -248,7 +233,6 @@ def test_boundary_matrix_packets1( "boundary", transform_mode=transform_mode, multiple_transforms=multiple_transforms, - lazy_init=lazy_init, ) @@ -357,7 +341,6 @@ def test_partial_expansion_1d(wavelet_str: str, boundary: str) -> None: wavelet_str, mode=boundary, maxlevel=max_lev, - lazy_init=True, ) # Full expansion of the wavelet packet tree @@ -377,23 +360,10 @@ def test_partial_expansion_1d(wavelet_str: str, boundary: str) -> None: assert all(key in lazy_init_packet for key in partial_keys_1d) - eager_init_packet = WaveletPacket( - test_signal, - wavelet_str, - mode=boundary, - maxlevel=max_lev, - lazy_init=False, - ) + # init on full keys + [lazy_init_packet[key] for key in full_keys] - assert all(key in eager_init_packet for key in full_keys) - - diffs = [ - ((lazy_init_packet[key] - eager_init_packet[key]) ** 2).sum() - for key in lazy_init_packet.keys() - ] - delta = torch.sum(torch.stack(diffs)) - - assert torch.isclose(delta, torch.tensor(0.0)) + assert all(key in lazy_init_packet for key in full_keys) @pytest.mark.parametrize("wavelet_str", ["haar", "db4"]) @@ -412,7 +382,6 @@ def test_partial_expansion_2d(wavelet_str: str, boundary: str) -> None: wavelet_str, mode=boundary, maxlevel=max_lev, - lazy_init=True, separable=True, ) @@ -430,24 +399,10 @@ def test_partial_expansion_2d(wavelet_str: str, boundary: str) -> None: assert all(key in lazy_init_packet for key in partial_keys_2d) - eager_init_packet = WaveletPacket2D( - test_signal, - wavelet_str, - mode=boundary, - maxlevel=max_lev, - lazy_init=False, - separable=True, - ) - - assert all(key in eager_init_packet for key in full_keys) + # init on full keys + [lazy_init_packet[key] for key in full_keys] - diffs = [ - ((lazy_init_packet[key] - eager_init_packet[key]) ** 2).sum() - for key in lazy_init_packet.keys() - ] - delta = torch.sum(torch.stack(diffs)) - - assert torch.isclose(delta, torch.tensor(0.0)) + assert all(key in lazy_init_packet for key in full_keys) def test_packet_harbo_lvl3() -> None: @@ -508,14 +463,12 @@ def test_access_errors_2d() -> None: @pytest.mark.parametrize("shape", [[1, 64, 63], [3, 64, 64], [1, 128]]) @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) @pytest.mark.parametrize("axis", (1, -1)) -@pytest.mark.parametrize("lazy_init", [True, False]) def test_inverse_packet_1d( level: int, base_key: str, shape: list[int], wavelet: str, axis: int, - lazy_init: bool, ) -> None: """Test the 1d reconstruction code.""" signal = np.random.randn(*shape) @@ -527,14 +480,12 @@ def test_inverse_packet_1d( mode=mode, maxlevel=level, axis=axis, - lazy_init=lazy_init, ) - if lazy_init: - with pytest.raises(KeyError): - ptwp.reconstruct() + with pytest.raises(KeyError): + ptwp.reconstruct() - # lazy init - [ptwp[key] for key in ptwp.get_level(level)] + # lazy init + [ptwp[key] for key in ptwp.get_level(level)] wp[base_key * level].data *= 0 ptwp[base_key * level] *= 0 @@ -549,14 +500,12 @@ def test_inverse_packet_1d( @pytest.mark.parametrize("size", [(32, 32, 32), (32, 32, 31, 64)]) @pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"]) @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) -@pytest.mark.parametrize("lazy_init", [True, False]) def test_inverse_packet_2d( level: int, base_key: str, size: tuple[int, ...], wavelet: str, axes: tuple[int, int], - lazy_init: bool, ) -> None: """Test the 2d reconstruction code.""" signal = np.random.randn(*size) @@ -568,15 +517,13 @@ def test_inverse_packet_2d( mode=mode, maxlevel=level, axes=axes, - lazy_init=lazy_init, ) wp[base_key * level].data *= 0 - if lazy_init: - with pytest.raises(KeyError): - ptwp.reconstruct() + with pytest.raises(KeyError): + ptwp.reconstruct() - # lazy init - [ptwp[key] for key in ptwp.get_natural_order(level)] + # lazy init + [ptwp[key] for key in ptwp.get_natural_order(level)] ptwp[base_key * level] *= 0 wp.reconstruct(update=True) @@ -616,8 +563,7 @@ def test_inverse_boundary_packet_2d() -> None: @pytest.mark.slow @pytest.mark.parametrize("axes", ((-2, -1), (1, 2), (2, 1))) -@pytest.mark.parametrize("lazy_init", [True, False]) -def test_separable_conv_packets_2d(axes: tuple[int, int], lazy_init: bool) -> None: +def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None: """Ensure the 2d separable conv code is ok.""" wavelet = "db2" signal = np.random.randn(1, 32, 32, 32) @@ -628,13 +574,11 @@ def test_separable_conv_packets_2d(axes: tuple[int, int], lazy_init: bool) -> No maxlevel=2, axes=axes, separable=True, - lazy_init=lazy_init, ) - if lazy_init: - with pytest.raises(KeyError): - ptwp.reconstruct() + with pytest.raises(KeyError): + ptwp.reconstruct() - # lazy init - [ptwp[key] for key in ptwp.get_natural_order(2)] + # lazy init + [ptwp[key] for key in ptwp.get_natural_order(2)] ptwp.reconstruct() assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32]) From 426691eaa8e35dbe97e18a0d260afbb1985f6545 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Wed, 26 Jun 2024 18:09:19 +0200 Subject: [PATCH 15/21] Add partial init func --- src/ptwt/packets.py | 20 ++++++++++++++++++++ tests/test_packets.py | 14 +++++++------- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index c01825f5..24d91529 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -146,6 +146,16 @@ def transform( self.maxlevel = maxlevel return self + def initialize(self, keys: Iterable[str]) -> None: + """Initialize the wavelet packet tree partially. + + Args: + keys (Iterable[str]): An iterable yielding the keys of the + tree nodes to initialize. + """ + it = (self[key] for key in keys) + collections.deque(it, maxlen=0) + def reconstruct(self) -> WaveletPacket: """Recursively reconstruct the input starting from the leaf nodes. @@ -403,6 +413,16 @@ def transform( return self + def initialize(self, keys: Iterable[str]) -> None: + """Initialize the wavelet packet tree partially. + + Args: + keys (Iterable[str]): An iterable yielding the keys of the + tree nodes to initialize. + """ + it = (self[key] for key in keys) + collections.deque(it, maxlen=0) + def reconstruct(self) -> WaveletPacket2D: """Recursively reconstruct the input starting from the leaf nodes. diff --git a/tests/test_packets.py b/tests/test_packets.py index b41e50d4..9f9a1076 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -353,7 +353,7 @@ def test_partial_expansion_1d(wavelet_str: str, boundary: str) -> None: assert all(key in lazy_init_packet for key in partial_keys_1d) # init on partial keys - [lazy_init_packet[key] for key in partial_keys_1d] + lazy_init_packet.initialize(partial_keys_1d) with pytest.raises(AssertionError): assert all(key in lazy_init_packet for key in full_keys) @@ -361,7 +361,7 @@ def test_partial_expansion_1d(wavelet_str: str, boundary: str) -> None: assert all(key in lazy_init_packet for key in partial_keys_1d) # init on full keys - [lazy_init_packet[key] for key in full_keys] + lazy_init_packet.initialize(full_keys) assert all(key in lazy_init_packet for key in full_keys) @@ -392,7 +392,7 @@ def test_partial_expansion_2d(wavelet_str: str, boundary: str) -> None: assert all(key in lazy_init_packet for key in partial_keys_2d) # init on partial keys - [lazy_init_packet[key] for key in partial_keys_2d] + lazy_init_packet.initialize(partial_keys_2d) with pytest.raises(AssertionError): assert all(key in lazy_init_packet for key in full_keys) @@ -400,7 +400,7 @@ def test_partial_expansion_2d(wavelet_str: str, boundary: str) -> None: assert all(key in lazy_init_packet for key in partial_keys_2d) # init on full keys - [lazy_init_packet[key] for key in full_keys] + lazy_init_packet.initialize(full_keys) assert all(key in lazy_init_packet for key in full_keys) @@ -485,7 +485,7 @@ def test_inverse_packet_1d( ptwp.reconstruct() # lazy init - [ptwp[key] for key in ptwp.get_level(level)] + ptwp.initialize(ptwp.get_level(level)) wp[base_key * level].data *= 0 ptwp[base_key * level] *= 0 @@ -523,7 +523,7 @@ def test_inverse_packet_2d( ptwp.reconstruct() # lazy init - [ptwp[key] for key in ptwp.get_natural_order(level)] + ptwp.initialize(ptwp.get_natural_order(level)) ptwp[base_key * level] *= 0 wp.reconstruct(update=True) @@ -579,6 +579,6 @@ def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None: ptwp.reconstruct() # lazy init - [ptwp[key] for key in ptwp.get_natural_order(2)] + ptwp.initialize(ptwp.get_natural_order(2)) ptwp.reconstruct() assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32]) From 315611dce6bb0335d2552d8e8699f3ebdf5f98b3 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Wed, 26 Jun 2024 18:09:34 +0200 Subject: [PATCH 16/21] Fix key lookup --- src/ptwt/packets.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 24d91529..895a7c94 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -112,6 +112,9 @@ def __init__( self._matrix_waverec_dict: dict[int, MatrixWaverec] = {} self.maxlevel: Optional[int] = None self.axis = axis + + self._filter_keys = {"a", "d"} + if data is not None: self.transform(data, maxlevel) else: @@ -188,12 +191,9 @@ def reconstruct(self) -> WaveletPacket: for node in self.get_level(level): # check if any children is not available # we need to check manually to avoid lazy init - def _test_key(key: str) -> None: - if key not in self: - raise KeyError(f"Key {key} not found") - - for child in ["a", "d"]: - _test_key(node + child) + for child in self._filter_keys: + if node + child not in self: + raise KeyError(f"Key {node + child} not found") data_a = self[node + "a"] data_d = self[node + "d"] @@ -315,6 +315,11 @@ def __getitem__(self, key: str) -> torch.Tensor: "The wavelet packet tree is not properly initialized. " "Run `transform` before accessing tree values." ) + elif key[-1] not in self._filter_keys: + raise ValueError( + f"Invalid key '{key}'. All chars in the key must be of the " + f"set {self._filter_keys}." + ) # calculate data from parent self._expand_node(key[:-1]) return super().__getitem__(key) @@ -375,6 +380,7 @@ def __init__( self.matrix_wavedec2_dict: dict[tuple[int, ...], MatrixWavedec2] = {} self.matrix_waverec2_dict: dict[tuple[int, ...], MatrixWaverec2] = {} self.axes = axes + self._filter_keys = {"a", "h", "v", "d"} self.maxlevel: Optional[int] = None if data is not None: @@ -442,12 +448,9 @@ def reconstruct(self) -> WaveletPacket2D: for node in WaveletPacket2D.get_natural_order(level): # check if any children is not available # we need to check manually to avoid lazy init - def _test_key(key: str) -> None: - if key not in self: - raise KeyError(f"Key {key} not found") - - for child in ["a", "h", "v", "d"]: - _test_key(node + child) + for child in self._filter_keys: + if node + child not in self: + raise KeyError(f"Key {node + child} not found") data_a = self[node + "a"] data_h = self[node + "h"] @@ -605,6 +608,11 @@ def __getitem__(self, key: str) -> torch.Tensor: "The wavelet packet tree is not properly initialized. " "Run `transform` before accessing tree values." ) + elif key[-1] not in self._filter_keys: + raise ValueError( + f"Invalid key '{key}'. All chars in the key must be of the " + f"set {self._filter_keys}." + ) # calculate data from parent self._expand_node(key[:-1]) From 2e2d2c7bb4099941269b22b6962b1b95021e69d4 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Wed, 26 Jun 2024 18:09:43 +0200 Subject: [PATCH 17/21] Fix example var --- examples/deepfake_analysis/packet_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/deepfake_analysis/packet_plot.py b/examples/deepfake_analysis/packet_plot.py index 83123f6e..dcc6baf7 100644 --- a/examples/deepfake_analysis/packet_plot.py +++ b/examples/deepfake_analysis/packet_plot.py @@ -68,7 +68,7 @@ def load_images(path: str) -> list: if __name__ == "__main__": freq_path = ptwt.WaveletPacket2D.get_freq_order(level=3) - frequency_path = ptwt.WaveletPacket2D.get_natural_order(level=3) + natural_path = ptwt.WaveletPacket2D.get_natural_order(level=3) print("Loading ffhq images:") ffhq_images = load_images("./ffhq_style_gan/source_data/A_ffhq") print("processing ffhq") From 41ca46152e2a301c40a515dd008536f652175f9b Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Wed, 26 Jun 2024 18:15:29 +0200 Subject: [PATCH 18/21] Fix test cases --- tests/test_packets.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index 9f9a1076..6c3f16f9 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -535,10 +535,12 @@ def test_inverse_boundary_packet_1d() -> None: """Test the 2d boundary reconstruction code.""" signal = np.random.randn(1, 16) wp = pywt.WaveletPacket(signal, "haar", mode="zero", maxlevel=2) - ptwp = WaveletPacket(torch.from_numpy(signal), "haar", mode="boundary", maxlevel=2) wp["aa"].data *= 0 - ptwp["aa"].data *= 0 wp.reconstruct(update=True) + + ptwp = WaveletPacket(torch.from_numpy(signal), "haar", mode="boundary", maxlevel=2) + ptwp.initialize(["ad", "da", "dd"]) + ptwp["aa"] *= 0 ptwp.reconstruct() assert np.allclose(wp[""].data, ptwp[""].numpy()[:, :16]) @@ -550,14 +552,18 @@ def test_inverse_boundary_packet_2d() -> None: base_key = "h" wavelet = "haar" signal = np.random.randn(1, size[0], size[1]) + wp = pywt.WaveletPacket2D(signal, wavelet, mode="zero", maxlevel=level) + wp[base_key * level].data *= 0 + wp.reconstruct(update=True) + ptwp = WaveletPacket2D( torch.from_numpy(signal), wavelet, mode="boundary", maxlevel=level ) - wp[base_key * level].data *= 0 - ptwp[base_key * level].data *= 0 - wp.reconstruct(update=True) + ptwp.initialize(WaveletPacket2D.get_natural_order(level)) + ptwp[base_key * level] *= 0 ptwp.reconstruct() + assert np.allclose(wp[""].data, ptwp[""].numpy()[:, : size[0], : size[1]]) From 4cc8a7e36cd7fecab0181cebbdac1a41c12105be Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Wed, 26 Jun 2024 18:16:55 +0200 Subject: [PATCH 19/21] Add comment --- src/ptwt/packets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 895a7c94..95f58d69 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -157,6 +157,7 @@ def initialize(self, keys: Iterable[str]) -> None: tree nodes to initialize. """ it = (self[key] for key in keys) + # exhaust iterator without storing all values collections.deque(it, maxlen=0) def reconstruct(self) -> WaveletPacket: @@ -427,6 +428,7 @@ def initialize(self, keys: Iterable[str]) -> None: tree nodes to initialize. """ it = (self[key] for key in keys) + # exhaust iterator without storing all values collections.deque(it, maxlen=0) def reconstruct(self) -> WaveletPacket2D: From 41aeaa218b1ead966ed28907f87e394aeedf1078 Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Thu, 27 Jun 2024 15:51:40 +0200 Subject: [PATCH 20/21] add test. --- tests/test_packets.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_packets.py b/tests/test_packets.py index 6c3f16f9..d79ed15a 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -588,3 +588,24 @@ def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None: ptwp.initialize(ptwp.get_natural_order(2)) ptwp.reconstruct() assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32]) + + + +def test_partial_reconstruction() -> None: + + signal = np.random.randn(1, 16) + signal2 = np.cos(np.linspace(0, 2 * np.pi, 16)) + ptwp = WaveletPacket(torch.from_numpy(signal), "haar", + mode="reflect", maxlevel=2) + ptwp.initialize(["aa", "ad", "da", "dd"]) + + ptwp2 = WaveletPacket(torch.from_numpy(signal2), "haar", mode="reflect", maxlevel=2) + + # overwrite the first packet set. + ptwp["aa"] = ptwp2["aa"] + ptwp["ad"] = ptwp2["ad"] + ptwp["da"] = ptwp2["da"] + ptwp["dd"] = ptwp2["dd"] + ptwp.reconstruct() + + assert np.allclose(signal2, ptwp[""].numpy()[:16]) From bb79f6ebc7f76264b37db8e7a726b9c1ae705b0d Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Thu, 27 Jun 2024 15:56:31 +0200 Subject: [PATCH 21/21] fix formatting. --- tests/test_packets.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index d79ed15a..6392cf16 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -590,17 +590,15 @@ def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None: assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32]) - def test_partial_reconstruction() -> None: - + """Reconstruct a cosine wave from packet filters.""" signal = np.random.randn(1, 16) signal2 = np.cos(np.linspace(0, 2 * np.pi, 16)) - ptwp = WaveletPacket(torch.from_numpy(signal), "haar", - mode="reflect", maxlevel=2) + ptwp = WaveletPacket(torch.from_numpy(signal), "haar", mode="reflect", maxlevel=2) ptwp.initialize(["aa", "ad", "da", "dd"]) - + ptwp2 = WaveletPacket(torch.from_numpy(signal2), "haar", mode="reflect", maxlevel=2) - + # overwrite the first packet set. ptwp["aa"] = ptwp2["aa"] ptwp["ad"] = ptwp2["ad"]