From 5f740e697ffffeefdd847e487a68a1a6cc235fec Mon Sep 17 00:00:00 2001
From: Charles Tapley Hoyt
Date: Sun, 28 Jan 2024 23:52:44 +0100
Subject: [PATCH] More typing

---
 tests/test_convolution_fwt.py    | 36 +++++++++++++++++---------------
 tests/test_jit.py                |  8 +++----
 tests/test_packets.py            |  3 ++-
 tests/test_separable_conv_fwt.py | 16 +++++++++-----
 tests/test_util.py               |  2 +-
 5 files changed, 37 insertions(+), 28 deletions(-)

diff --git a/tests/test_convolution_fwt.py b/tests/test_convolution_fwt.py
index ff25362c..b50e2ff8 100644
--- a/tests/test_convolution_fwt.py
+++ b/tests/test_convolution_fwt.py
@@ -1,6 +1,6 @@
 """Test the conv-fwt code."""

-from typing import List, Optional, Sequence
+from typing import List, Optional, Sequence, Tuple

 # Written by moritz ( @ wolter.tech ) in 2021
 import numpy as np
@@ -140,7 +140,7 @@ def test_1d_multibatch(level: Optional[int], shape: Sequence[int]) -> None:


 @pytest.mark.parametrize("axis", [-1, 0, 1, 2])
-def test_1d_axis_arg(axis):
+def test_1d_axis_arg(axis: int):
     """Ensure the axis argument works as expected."""
     data = torch.randn([16, 16, 16], dtype=torch.float64)

@@ -188,7 +188,7 @@ def test_2d_db2_lvl1() -> None:
     assert np.allclose(rec.numpy().squeeze(), face)


-def test_2d_haar_multi():
+def test_2d_haar_multi() -> None:
     """Test a 2d-db2 wavelet level 5 conv-fwt."""
     # multi level haar - 2d
     face = np.transpose(
@@ -205,7 +205,7 @@ def test_2d_haar_multi():
     assert np.allclose(rec, face)


-def test_outer():
+def test_outer() -> None:
     """Test the outer-product implementation."""
     a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
     b = torch.tensor([6.0, 7.0, 8.0, 9.0, 10.0])
@@ -223,7 +223,9 @@ def test_outer():
 @pytest.mark.parametrize(
     "mode", ["reflect", "zero", "constant", "periodic", "symmetric"]
 )
-def test_2d_wavedec_rec(wavelet_str, level, size, mode):
+def test_2d_wavedec_rec(
+    wavelet_str: str, level: Optional[int], size: Tuple[int, int], mode: BoundaryMode
+):
     """Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients.

     wavedec2 and waverec2 must invert each other.
@@ -257,7 +259,9 @@ def test_2d_wavedec_rec(wavelet_str, level, size, mode):
 )
 @pytest.mark.parametrize("level", [1, None])
 @pytest.mark.parametrize("wavelet", ["haar", "sym3"])
-def test_input_4d(size, level, wavelet):
+def test_input_4d(
+    size: Tuple[int, int, int, int], level: Optional[int], wavelet: str
+) -> None:
     """Test the error for 4d inputs to wavedec2."""
     data = torch.randn(*size).type(torch.float64)

@@ -284,20 +288,20 @@ def test_input_4d(size, level, wavelet):


 @pytest.mark.parametrize("padding_str", ["invalid_padding_name"])
-def test_incorrect_padding(padding_str):
+def test_incorrect_padding(padding_str: BoundaryMode) -> None:
     """Test expected errors for an invalid padding name."""
     with pytest.raises(ValueError):
         _ = _translate_boundary_strings(padding_str)


-def test_input_1d_dimension_error():
+def test_input_1d_dimension_error() -> None:
     """Test the error for 1d inputs to wavedec2."""
     with pytest.raises(ValueError):
         data = torch.randn(50)
         wavedec2(data, "haar", level=4)


-def _compare_coeffs(ptwt_res, pywt_res):
+def _compare_coeffs(ptwt_res, pywt_res) -> List[bool]:
     """Compare coefficient lists.

     Args:
@@ -311,10 +315,8 @@ def _compare_coeffs(ptwt_res, pywt_res):
     for ptwtcs, pywtcs in zip(ptwt_res, pywt_res):
         if isinstance(ptwtcs, tuple):
             test_list.extend(
-                tuple(
-                    np.allclose(ptwtc.numpy(), pywtc)
-                    for ptwtc, pywtc in zip(ptwtcs, pywtcs)
-                )
+                np.allclose(ptwtc.numpy(), pywtc)
+                for ptwtc, pywtc in zip(ptwtcs, pywtcs)
             )
         else:
             test_list.append(np.allclose(ptwtcs.numpy(), pywtcs))
@@ -325,7 +327,7 @@ def _compare_coeffs(ptwt_res, pywt_res):
 @pytest.mark.parametrize(
     "size", [(50, 20, 128, 128), (8, 49, 21, 128, 128), (6, 4, 4, 5, 64, 64)]
 )
-def test_2d_multidim_input(size):
+def test_2d_multidim_input(size: Tuple[int, ...]) -> None:
     """Test the error for multi-dimensional inputs to wavedec2."""
     data = torch.randn(*size, dtype=torch.float64)
     wavelet = "db2"
@@ -347,7 +349,7 @@ def test_2d_multidim_input(size):

 @pytest.mark.slow
 @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0)])
-def test_2d_axis_argument(axes):
+def test_2d_axis_argument(axes: Tuple[int, int]) -> None:
     """Ensure the axes argument works as expected."""
     data = torch.randn([32, 32, 32, 32], dtype=torch.float64)

@@ -365,14 +367,14 @@ def test_2d_axis_argument(axes):
     )


-def test_2d_axis_error_axes_count():
+def test_2d_axis_error_axes_count() -> None:
     """Check the error for too many axes."""
     with pytest.raises(ValueError):
         data = torch.randn([32, 32, 32, 32], dtype=torch.float64)
         wavedec2(data, "haar", level=1, axes=(1, 2, 3))


-def test_2d_axis_error_axes_repetition():
+def test_2d_axis_error_axes_repetition() -> None:
     """Check the error for axes repetition."""
     with pytest.raises(ValueError):
         data = torch.randn([32, 32, 32, 32], dtype=torch.float64)
diff --git a/tests/test_jit.py b/tests/test_jit.py
index e547343a..d403ea34 100644
--- a/tests/test_jit.py
+++ b/tests/test_jit.py
@@ -1,6 +1,6 @@
 """Ensure pytorch's torch.jit.trace feature works properly."""

-from typing import NamedTuple
+from typing import List, NamedTuple

 import numpy as np
 import pytest
@@ -82,7 +82,7 @@ def _to_jit_wavedec_2(data, wavelet):
     return coeff2


-def _to_jit_waverec_2(data, wavelet):
+def _to_jit_waverec_2(data, wavelet) -> torch.Tensor:
     """Undo the stacking from the jit wavedec2 wrapper."""
     d_unstack = [data[0]]
     for c in data[1:]:
@@ -113,7 +113,7 @@ def test_conv_fwt_jit_2d() -> None:
     assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7)


-def _to_jit_wavedec_3(data, wavelet):
+def _to_jit_wavedec_3(data: torch.Tensor, wavelet: str) -> List[torch.Tensor]:
     """Ensure uniform datatypes in lists for the tracer.

     Going from List[Union[torch.Tensor, Dict[str, torch.Tensor]]] to List[torch.Tensor]
@@ -165,7 +165,7 @@ def test_conv_fwt_jit_3d() -> None:
     assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7)


-def _to_jit_cwt(sig):
+def _to_jit_cwt(sig: torch.Tensor) -> torch.Tensor:
     widths = torch.arange(1, 31)
     wavelet = _ShannonWavelet("shan0.1-0.4")
     sampling_period = (4 / 800) * np.pi
diff --git a/tests/test_packets.py b/tests/test_packets.py
index 5fd86601..b29ac102 100644
--- a/tests/test_packets.py
+++ b/tests/test_packets.py
@@ -3,6 +3,7 @@
 # Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de)

 from itertools import product
+from typing import Optional

 import numpy as np
 import pytest
@@ -131,7 +132,7 @@ def _compare_trees2(
 @pytest.mark.parametrize("transform_mode", [False, True])
 @pytest.mark.parametrize("multiple_transforms", [False, True])
 def test_2d_packets(
-    max_lev, wavelet_str, boundary, batch_size, transform_mode, multiple_transforms
+    max_lev: Optional[int], wavelet_str: str, boundary, batch_size, transform_mode, multiple_transforms
 ) -> None:
     """Ensure pywt and ptwt produce equivalent wavelet 2d packet trees."""
     _compare_trees2(
diff --git a/tests/test_separable_conv_fwt.py b/tests/test_separable_conv_fwt.py
index b81b7e51..cb7329aa 100644
--- a/tests/test_separable_conv_fwt.py
+++ b/tests/test_separable_conv_fwt.py
@@ -1,5 +1,7 @@
 """Separable transform test code."""

+from typing import Optional, Sequence, Tuple
+
 import numpy as np
 import pytest
 import pywt
@@ -21,7 +23,7 @@
 @pytest.mark.parametrize(
     "shape", ((12, 12), (24, 12, 12), (12, 24, 12), (12, 12, 12, 12))
 )
-def test_separable_conv(shape, level) -> None:
+def test_separable_conv(shape: Sequence[int], level: int) -> None:
     """Test the separable transforms."""
     data = np.random.randint(0, 9, shape)

@@ -62,7 +64,7 @@ def test_separable_conv(shape, level) -> None:

 @pytest.mark.parametrize("shape", [(5, 64, 64), (5, 65, 65), (5, 29, 29)])
 @pytest.mark.parametrize("wavelet", ["haar", "db3", "sym5"])
-def test_example_fs2d(shape, wavelet) -> None:
+def test_example_fs2d(shape: Sequence[int], wavelet: str) -> None:
     """Test 2d fully separable padding."""
     data = torch.randn(*shape).type(torch.float64)
     coeff = fswavedec2(data, wavelet, level=2)
@@ -72,7 +74,7 @@ def test_example_fs2d(shape, wavelet) -> None:

 @pytest.mark.parametrize("shape", [(5, 64, 64, 64), (5, 65, 65, 65), (5, 29, 29, 29)])
 @pytest.mark.parametrize("wavelet", ["haar", "db3", "sym5"])
-def test_example_fs3d(shape, wavelet) -> None:
+def test_example_fs3d(shape: Sequence[int], wavelet: str) -> None:
     """Test 3d fully separable padding."""
     data = torch.randn(*shape).type(torch.float64)
     coeff = fswavedec3(data, wavelet, level=2)
@@ -87,7 +89,9 @@ def test_example_fs3d(shape, wavelet) -> None:
     "shape", [[1, 64, 128, 128], [1, 3, 64, 64, 64], [2, 1, 64, 64, 64]]
 )
 @pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (2, 3), (3, 2)])
-def test_conv_mm_2d(level, shape, axes) -> None:
+def test_conv_mm_2d(
+    level: Optional[int], shape: Sequence[int], axes: Tuple[int, int]
+) -> None:
     """Compare mm and conv fully separable results."""
     data = torch.randn(*shape).type(torch.float64)
     fs_conv_coeff = fswavedec2(data, "haar", level=level, axes=axes)
@@ -114,7 +118,9 @@ def test_conv_mm_2d(level, shape, axes) -> None:
 @pytest.mark.parametrize("level", [1, 2, 3, None])
 @pytest.mark.parametrize("axes", [(-3, -2, -1), (-1, -2, -3), (2, 3, 1)])
 @pytest.mark.parametrize("shape", [(5, 64, 128, 256)])
-def test_conv_mm_3d(level, axes, shape) -> None:
+def test_conv_mm_3d(
+    level: Optional[int], axes: Tuple[int, int, int], shape: Tuple[int, ...]
+) -> None:
     """Compare mm and conv 3d fully separable results."""
     data = torch.randn(*shape).type(torch.float64)
     fs_conv_coeff = fswavedec3(data, "haar", level=level, axes=axes)
diff --git a/tests/test_util.py b/tests/test_util.py
index ad37e2e5..7109ac57 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -18,7 +18,7 @@ class _MyHaarFilterBank:

     @property
-    def filter_bank(self) -> Tuple[list, list, list, list]:
+    def filter_bank(self) -> Tuple[List[float], List[float], List[float], List[float]]:
         """Unscaled Haar wavelet filters."""
         return (
             [1 / 2, 1 / 2.0],