diff --git a/tests/test_matrix_fwt.py b/tests/test_matrix_fwt.py index b83b03df..d95f2081 100644 --- a/tests/test_matrix_fwt.py +++ b/tests/test_matrix_fwt.py @@ -10,7 +10,7 @@ import pywt import torch -from ptwt.constants import OrthogonalizeMethod +from ptwt.constants import BoundaryMode, OrthogonalizeMethod from ptwt.matmul_transform import ( MatrixWavedec, MatrixWaverec, @@ -71,12 +71,17 @@ def test_fwt_ifwt_mackey_haar_cuda() -> None: @pytest.mark.parametrize("level", [1, 2, 3, 4, None]) @pytest.mark.parametrize("wavelet", ["db2", "db3", "db4", "sym5"]) @pytest.mark.parametrize("size", [[2, 256], [2, 3, 256], [1, 1, 128]]) -def test_1d_matrix_fwt_ifwt(level: int, wavelet: str, size: list[int]) -> None: +@pytest.mark.parametrize( + "mode", ["reflect", "zero", "constant", "periodic", "symmetric"] +) +def test_1d_matrix_fwt_ifwt( + level: int, wavelet: str, size: list[int], mode: BoundaryMode +) -> None: """Test multiple wavelets and levels for a long signal.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") wavelet = pywt.Wavelet(wavelet) pt_data = torch.randn(*size, device=device).type(torch.float64) - matrix_wavedec = MatrixWavedec(wavelet, level) + matrix_wavedec = MatrixWavedec(wavelet, level, odd_coeff_padding_mode=mode) coeffs_mat_max = matrix_wavedec(pt_data) matrix_waverec = MatrixWaverec(wavelet) reconstructed_data = matrix_waverec(coeffs_mat_max) diff --git a/tests/test_matrix_fwt_2.py b/tests/test_matrix_fwt_2.py index 92a07197..af04fffb 100644 --- a/tests/test_matrix_fwt_2.py +++ b/tests/test_matrix_fwt_2.py @@ -8,6 +8,7 @@ import scipy.signal import torch +from ptwt.constants import BoundaryMode from ptwt.conv_transform import _flatten_2d_coeff_lst from ptwt.matmul_transform import BaseMatrixWaveDec, MatrixWavedec, MatrixWaverec from ptwt.matmul_transform_2 import ( @@ -86,15 +87,24 @@ def test_matrix_analysis_fwt_2d_haar(size: tuple[int, int], level: int) -> None: ) @pytest.mark.parametrize("level", [1, 2, 3, None]) @pytest.mark.parametrize("separable", [False, True]) +@pytest.mark.parametrize( + "mode", ["reflect", "zero", "constant", "periodic", "symmetric"] +) def test_boundary_matrix_fwt_2d( - wavelet_str: str, size: tuple[int, int], level: int, separable: bool + wavelet_str: str, + size: tuple[int, int], + level: int, + separable: bool, + mode: BoundaryMode, ) -> None: """Ensure the boundary matrix fwt is invertable.""" face = np.mean( scipy.datasets.face()[256 : (256 + size[0]), 256 : (256 + size[1])], -1 ).astype(np.float64) wavelet = pywt.Wavelet(wavelet_str) - matrixfwt = MatrixWavedec2(wavelet, level=level, separable=separable) + matrixfwt = MatrixWavedec2( + wavelet, level=level, separable=separable, odd_coeff_padding_mode=mode + ) mat_coeff = matrixfwt(torch.from_numpy(face)) matrixifwt = MatrixWaverec2(wavelet, separable=separable) reconstruction = matrixifwt(mat_coeff).squeeze(0) diff --git a/tests/test_matrix_fwt_3.py b/tests/test_matrix_fwt_3.py index 2856e591..d07848a2 100644 --- a/tests/test_matrix_fwt_3.py +++ b/tests/test_matrix_fwt_3.py @@ -7,6 +7,7 @@ import pywt import torch +from ptwt.constants import BoundaryMode from ptwt.matmul_transform import construct_boundary_a from ptwt.matmul_transform_3 import MatrixWavedec3, MatrixWaverec3 from ptwt.sparse_math import _batch_dim_mm @@ -75,13 +76,16 @@ def test_boundary_wavedec3_level1_haar(shape: tuple[int, int, int]) -> None: @pytest.mark.parametrize( "shape", [(31, 32, 33), (63, 35, 32), (32, 62, 31), (32, 32, 64)] ) +@pytest.mark.parametrize( + "mode", ["reflect", "zero", "constant", "periodic", "symmetric"] +) def test_boundary_wavedec3_inverse( - level: Optional[int], shape: tuple[int, int, int] + level: Optional[int], shape: tuple[int, int, int], mode: BoundaryMode ) -> None: """Test the 3d matrix wavedec and the padding for odd axes.""" batch_size = 1 test_data = torch.rand(batch_size, shape[0], shape[1], shape[2]).type(torch.float64) - ptwtres = MatrixWavedec3("haar", level)(test_data) + ptwtres = MatrixWavedec3("haar", level, odd_coeff_padding_mode=mode)(test_data) rec = MatrixWaverec3("haar")(ptwtres) assert np.allclose( test_data.numpy(), rec[:, : shape[0], : shape[1], : shape[2]].numpy()