Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
felixblanke committed Jun 26, 2024
1 parent dd586c6 commit 5862bb5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
11 changes: 8 additions & 3 deletions tests/test_matrix_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pywt
import torch

from ptwt.constants import OrthogonalizeMethod
from ptwt.constants import BoundaryMode, OrthogonalizeMethod
from ptwt.matmul_transform import (
MatrixWavedec,
MatrixWaverec,
Expand Down Expand Up @@ -71,12 +71,17 @@ def test_fwt_ifwt_mackey_haar_cuda() -> None:
@pytest.mark.parametrize("level", [1, 2, 3, 4, None])
@pytest.mark.parametrize("wavelet", ["db2", "db3", "db4", "sym5"])
@pytest.mark.parametrize("size", [[2, 256], [2, 3, 256], [1, 1, 128]])
def test_1d_matrix_fwt_ifwt(level: int, wavelet: str, size: list[int]) -> None:
@pytest.mark.parametrize(
"mode", ["reflect", "zero", "constant", "periodic", "symmetric"]
)
def test_1d_matrix_fwt_ifwt(
level: int, wavelet: str, size: list[int], mode: BoundaryMode
) -> None:
"""Test multiple wavelets and levels for a long signal."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wavelet = pywt.Wavelet(wavelet)
pt_data = torch.randn(*size, device=device).type(torch.float64)
matrix_wavedec = MatrixWavedec(wavelet, level)
matrix_wavedec = MatrixWavedec(wavelet, level, odd_coeff_padding_mode=mode)
coeffs_mat_max = matrix_wavedec(pt_data)
matrix_waverec = MatrixWaverec(wavelet)
reconstructed_data = matrix_waverec(coeffs_mat_max)
Expand Down
14 changes: 12 additions & 2 deletions tests/test_matrix_fwt_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import scipy.signal
import torch

from ptwt.constants import BoundaryMode
from ptwt.conv_transform import _flatten_2d_coeff_lst
from ptwt.matmul_transform import BaseMatrixWaveDec, MatrixWavedec, MatrixWaverec
from ptwt.matmul_transform_2 import (
Expand Down Expand Up @@ -86,15 +87,24 @@ def test_matrix_analysis_fwt_2d_haar(size: tuple[int, int], level: int) -> None:
)
@pytest.mark.parametrize("level", [1, 2, 3, None])
@pytest.mark.parametrize("separable", [False, True])
@pytest.mark.parametrize(
"mode", ["reflect", "zero", "constant", "periodic", "symmetric"]
)
def test_boundary_matrix_fwt_2d(
wavelet_str: str, size: tuple[int, int], level: int, separable: bool
wavelet_str: str,
size: tuple[int, int],
level: int,
separable: bool,
mode: BoundaryMode,
) -> None:
"""Ensure the boundary matrix fwt is invertable."""
face = np.mean(
scipy.datasets.face()[256 : (256 + size[0]), 256 : (256 + size[1])], -1
).astype(np.float64)
wavelet = pywt.Wavelet(wavelet_str)
matrixfwt = MatrixWavedec2(wavelet, level=level, separable=separable)
matrixfwt = MatrixWavedec2(
wavelet, level=level, separable=separable, odd_coeff_padding_mode=mode
)
mat_coeff = matrixfwt(torch.from_numpy(face))
matrixifwt = MatrixWaverec2(wavelet, separable=separable)
reconstruction = matrixifwt(mat_coeff).squeeze(0)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_matrix_fwt_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pywt
import torch

from ptwt.constants import BoundaryMode
from ptwt.matmul_transform import construct_boundary_a
from ptwt.matmul_transform_3 import MatrixWavedec3, MatrixWaverec3
from ptwt.sparse_math import _batch_dim_mm
Expand Down Expand Up @@ -75,13 +76,16 @@ def test_boundary_wavedec3_level1_haar(shape: tuple[int, int, int]) -> None:
@pytest.mark.parametrize(
"shape", [(31, 32, 33), (63, 35, 32), (32, 62, 31), (32, 32, 64)]
)
@pytest.mark.parametrize(
"mode", ["reflect", "zero", "constant", "periodic", "symmetric"]
)
def test_boundary_wavedec3_inverse(
level: Optional[int], shape: tuple[int, int, int]
level: Optional[int], shape: tuple[int, int, int], mode: BoundaryMode
) -> None:
"""Test the 3d matrix wavedec and the padding for odd axes."""
batch_size = 1
test_data = torch.rand(batch_size, shape[0], shape[1], shape[2]).type(torch.float64)
ptwtres = MatrixWavedec3("haar", level)(test_data)
ptwtres = MatrixWavedec3("haar", level, odd_coeff_padding_mode=mode)(test_data)
rec = MatrixWaverec3("haar")(ptwtres)
assert np.allclose(
test_data.numpy(), rec[:, : shape[0], : shape[1], : shape[2]].numpy()
Expand Down

0 comments on commit 5862bb5

Please sign in to comment.