Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve packets order interface #95

Merged
merged 10 commits into from
Jun 26, 2024
8 changes: 8 additions & 0 deletions src/ptwt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@
Choose ``gramschmidt`` if ``qr`` runs out of memory.
"""

PacketNodeOrder = Literal["freq", "natural"]
"""
This is a type literal for the order of wavelet packet tree nodes.

- frequency order (``freq``)
- natural order (``natural``)
"""


class WaveletDetailTuple2d(NamedTuple):
"""Detail coefficients of a 2d wavelet transform for a given level.
Expand Down
63 changes: 58 additions & 5 deletions src/ptwt/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Sequence
from functools import partial
from itertools import product
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload

import numpy as np
import pywt
Expand All @@ -16,6 +16,7 @@
from .constants import (
ExtendedBoundaryMode,
OrthogonalizeMethod,
PacketNodeOrder,
WaveletCoeff2d,
WaveletCoeffNd,
WaveletDetailTuple2d,
Expand Down Expand Up @@ -203,18 +204,34 @@ def _get_waverec(
else:
return partial(waverec, wavelet=self.wavelet, axis=self.axis)

def get_level(self, level: int) -> list[str]:
"""Return the graycode-ordered paths to the filter tree nodes.
@staticmethod
def get_level(level: int, order: PacketNodeOrder = "freq") -> list[str]:
"""Return the paths to the filter tree nodes.

Args:
level (int): The depth of the tree.
order: The order the paths are in.
Choose from frequency order (``freq``) and
natural order (``natural``).
Defaults to ``freq``.

Returns:
A list with the paths to each node.

Raises:
ValueError: If `order` is neither ``freq`` nor ``natural``.
"""
return self._get_graycode_order(level)
if order == "freq":
return WaveletPacket._get_graycode_order(level)
elif order == "natural":
return ["".join(p) for p in product(["a", "d"], repeat=level)]
else:
raise ValueError(
f"Unsupported order '{order}'. Choose from 'freq' and 'natural'."
)

def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[str]:
@staticmethod
def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
graycode_order = [x, y]
for _ in range(level - 1):
graycode_order = [x + path for path in graycode_order] + [
Expand Down Expand Up @@ -514,6 +531,42 @@ def __getitem__(self, key: str) -> torch.Tensor:
)
return super().__getitem__(key)

@overload
@staticmethod
def get_level(level: int, order: Literal["freq"]) -> list[list[str]]: ...

@overload
@staticmethod
def get_level(level: int, order: Literal["natural"]) -> list[str]: ...

@staticmethod
def get_level(
level: int, order: PacketNodeOrder = "freq"
) -> Union[list[str], list[list[str]]]:
"""Return the paths to the filter tree nodes.

Args:
level (int): The depth of the tree.
order: The order the paths are in.
Choose from frequency order (``freq``) and
natural order (``natural``).
Defaults to ``freq``.

Returns:
A list with the paths to each node.

Raises:
ValueError: If `order` is neither ``freq`` nor ``natural``.
"""
if order == "freq":
return WaveletPacket2D.get_freq_order(level)
elif order == "natural":
return WaveletPacket2D.get_natural_order(level)
else:
raise ValueError(
f"Unsupported order '{order}'. Choose from 'freq' and 'natural'."
)

@staticmethod
def get_natural_order(level: int) -> list[str]:
"""Get the natural ordering for a given decomposition level.
Expand Down
40 changes: 40 additions & 0 deletions tests/test_packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,26 @@ def test_boundary_matrix_packets1(
)


@pytest.mark.parametrize("level", [1, 2, 3, 4])
@pytest.mark.parametrize("wavelet_str", ["db2"])
@pytest.mark.parametrize("pywt_boundary", ["zero"])
@pytest.mark.parametrize("order", ["freq", "natural"])
def test_order_1d(level: int, wavelet_str: str, pywt_boundary: str, order: str) -> None:
"""Test the packets in natural order."""
data = np.random.rand(2, 256)
wp_tree = pywt.WaveletPacket(
data=data,
wavelet=wavelet_str,
mode=pywt_boundary,
)
# Get the full decomposition
order_pywt = wp_tree.get_level(level, order)
order_ptwt = WaveletPacket.get_level(level, order)

for order_el, order_path in zip(order_pywt, order_ptwt):
assert order_el.path == order_path


@pytest.mark.parametrize("level", [1, 2, 3, 4])
@pytest.mark.parametrize("wavelet_str", ["db2"])
@pytest.mark.parametrize("pywt_boundary", ["zero"])
Expand All @@ -261,6 +281,26 @@ def test_freq_order_2d(level: int, wavelet_str: str, pywt_boundary: str) -> None
assert order_el.path == order_path


@pytest.mark.parametrize("level", [1, 2, 3, 4])
@pytest.mark.parametrize("wavelet_str", ["db2"])
@pytest.mark.parametrize("pywt_boundary", ["zero"])
def test_natural_order_2d(level: int, wavelet_str: str, pywt_boundary: str) -> None:
"""Test the packets in natural order."""
face = datasets.face()
wavelet = pywt.Wavelet(wavelet_str)
wp_tree = pywt.WaveletPacket2D(
data=np.mean(face, axis=-1).astype(np.float64),
wavelet=wavelet,
mode=pywt_boundary,
)
# Get the full decomposition
order_pywt = wp_tree.get_level(level, "natural")
order_ptwt = WaveletPacket2D.get_natural_order(level)

for order_el, order_path in zip(order_pywt, order_ptwt):
assert order_el.path == order_path


def test_packet_harbo_lvl3() -> None:
"""From Jensen, La Cour-Harbo, Rippels in Mathematics, Chapter 8 (page 89)."""
data = np.array([56.0, 40.0, 8.0, 24.0, 48.0, 48.0, 40.0, 16.0])
Expand Down
Loading