Skip to content

Commit

Permalink
More typing
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 28, 2024
1 parent f2d943e commit 5f740e6
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 28 deletions.
36 changes: 19 additions & 17 deletions tests/test_convolution_fwt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test the conv-fwt code."""

from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Tuple

# Written by moritz ( @ wolter.tech ) in 2021
import numpy as np
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_1d_multibatch(level: Optional[int], shape: Sequence[int]) -> None:


@pytest.mark.parametrize("axis", [-1, 0, 1, 2])
def test_1d_axis_arg(axis):
def test_1d_axis_arg(axis: int):
"""Ensure the axis argument works as expected."""
data = torch.randn([16, 16, 16], dtype=torch.float64)

Expand Down Expand Up @@ -188,7 +188,7 @@ def test_2d_db2_lvl1() -> None:
assert np.allclose(rec.numpy().squeeze(), face)


def test_2d_haar_multi():
def test_2d_haar_multi() -> None:
"""Test a 2d-db2 wavelet level 5 conv-fwt."""
# multi level haar - 2d
face = np.transpose(
Expand All @@ -205,7 +205,7 @@ def test_2d_haar_multi():
assert np.allclose(rec, face)


def test_outer():
def test_outer() -> None:
"""Test the outer-product implementation."""
a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
b = torch.tensor([6.0, 7.0, 8.0, 9.0, 10.0])
Expand All @@ -223,7 +223,9 @@ def test_outer():
@pytest.mark.parametrize(
"mode", ["reflect", "zero", "constant", "periodic", "symmetric"]
)
def test_2d_wavedec_rec(wavelet_str, level, size, mode):
def test_2d_wavedec_rec(
wavelet_str: str, level: Optional[int], size: Tuple[int, int], mode: BoundaryMode
):
"""Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients.
wavedec2 and waverec2 must invert each other.
Expand Down Expand Up @@ -257,7 +259,9 @@ def test_2d_wavedec_rec(wavelet_str, level, size, mode):
)
@pytest.mark.parametrize("level", [1, None])
@pytest.mark.parametrize("wavelet", ["haar", "sym3"])
def test_input_4d(size, level, wavelet):
def test_input_4d(
size: Tuple[int, int, int, int], level: Optional[str], wavelet: str
) -> None:
"""Test the error for 4d inputs to wavedec2."""
data = torch.randn(*size).type(torch.float64)

Expand All @@ -284,20 +288,20 @@ def test_input_4d(size, level, wavelet):


@pytest.mark.parametrize("padding_str", ["invalid_padding_name"])
def test_incorrect_padding(padding_str):
def test_incorrect_padding(padding_str: BoundaryMode) -> None:
"""Test expected errors for an invalid padding name."""
with pytest.raises(ValueError):
_ = _translate_boundary_strings(padding_str)


def test_input_1d_dimension_error():
def test_input_1d_dimension_error() -> None:
"""Test the error for 1d inputs to wavedec2."""
with pytest.raises(ValueError):
data = torch.randn(50)
wavedec2(data, "haar", level=4)


def _compare_coeffs(ptwt_res, pywt_res):
def _compare_coeffs(ptwt_res, pywt_res) -> List[bool]:
"""Compare coefficient lists.
Args:
Expand All @@ -311,10 +315,8 @@ def _compare_coeffs(ptwt_res, pywt_res):
for ptwtcs, pywtcs in zip(ptwt_res, pywt_res):
if isinstance(ptwtcs, tuple):
test_list.extend(
tuple(
np.allclose(ptwtc.numpy(), pywtc)
for ptwtc, pywtc in zip(ptwtcs, pywtcs)
)
np.allclose(ptwtc.numpy(), pywtc)
for ptwtc, pywtc in zip(ptwtcs, pywtcs)
)
else:
test_list.append(np.allclose(ptwtcs.numpy(), pywtcs))
Expand All @@ -325,7 +327,7 @@ def _compare_coeffs(ptwt_res, pywt_res):
@pytest.mark.parametrize(
"size", [(50, 20, 128, 128), (8, 49, 21, 128, 128), (6, 4, 4, 5, 64, 64)]
)
def test_2d_multidim_input(size):
def test_2d_multidim_input(size: Tuple[int, ...]) -> None:
"""Test the error for multi-dimensional inputs to wavedec2."""
data = torch.randn(*size, dtype=torch.float64)
wavelet = "db2"
Expand All @@ -347,7 +349,7 @@ def test_2d_multidim_input(size):

@pytest.mark.slow
@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (-3, -2), (0, 1), (1, 0)])
def test_2d_axis_argument(axes):
def test_2d_axis_argument(axes: Tuple[int, int]) -> None:
"""Ensure the axes argument works as expected."""
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)

Expand All @@ -365,14 +367,14 @@ def test_2d_axis_argument(axes):
)


def test_2d_axis_error_axes_count():
def test_2d_axis_error_axes_count() -> None:
"""Check the error for too many axes."""
with pytest.raises(ValueError):
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)
wavedec2(data, "haar", level=1, axes=(1, 2, 3))


def test_2d_axis_error_axes_repetition():
def test_2d_axis_error_axes_repetition() -> None:
"""Check the error for axes repetition."""
with pytest.raises(ValueError):
data = torch.randn([32, 32, 32, 32], dtype=torch.float64)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_jit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Ensure pytorch's torch.jit.trace feature works properly."""

from typing import NamedTuple
from typing import NamedTuple, List

import numpy as np
import pytest
Expand Down Expand Up @@ -82,7 +82,7 @@ def _to_jit_wavedec_2(data, wavelet):
return coeff2


def _to_jit_waverec_2(data, wavelet):
def _to_jit_waverec_2(data, wavelet) -> torch.Tensor:
"""Undo the stacking from the jit wavedec2 wrapper."""
d_unstack = [data[0]]
for c in data[1:]:
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_conv_fwt_jit_2d() -> None:
assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7)


def _to_jit_wavedec_3(data, wavelet):
def _to_jit_wavedec_3(data: torch.Tensor, wavelet: str) -> List[torch.Tensor]:
"""Ensure uniform datatypes in lists for the tracer.
Going from List[Union[torch.Tensor, Dict[str, torch.Tensor]]] to List[torch.Tensor]
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_conv_fwt_jit_3d() -> None:
assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7)


def _to_jit_cwt(sig):
def _to_jit_cwt(sig: torch.Tensor) -> torch.Tensor:
widths = torch.arange(1, 31)
wavelet = _ShannonWavelet("shan0.1-0.4")
sampling_period = (4 / 800) * np.pi
Expand Down
3 changes: 2 additions & 1 deletion tests/test_packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de)

from itertools import product
from typing import Optional

import numpy as np
import pytest
Expand Down Expand Up @@ -131,7 +132,7 @@ def _compare_trees2(
@pytest.mark.parametrize("transform_mode", [False, True])
@pytest.mark.parametrize("multiple_transforms", [False, True])
def test_2d_packets(
max_lev, wavelet_str, boundary, batch_size, transform_mode, multiple_transforms
max_lev: Optional[int], wavelet_str: str, boundary, batch_size, transform_mode, multiple_transforms
) -> None:
"""Ensure pywt and ptwt produce equivalent wavelet 2d packet trees."""
_compare_trees2(
Expand Down
16 changes: 11 additions & 5 deletions tests/test_separable_conv_fwt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Separable transform test code."""

from typing import Optional, Tuple, Sequence

import numpy as np
import pytest
import pywt
Expand All @@ -21,7 +23,7 @@
@pytest.mark.parametrize(
"shape", ((12, 12), (24, 12, 12), (12, 24, 12), (12, 12, 12, 12))
)
def test_separable_conv(shape, level) -> None:
def test_separable_conv(shape: Sequence[int], level: int) -> None:
"""Test the separable transforms."""
data = np.random.randint(0, 9, shape)

Expand Down Expand Up @@ -62,7 +64,7 @@ def test_separable_conv(shape, level) -> None:

@pytest.mark.parametrize("shape", [(5, 64, 64), (5, 65, 65), (5, 29, 29)])
@pytest.mark.parametrize("wavelet", ["haar", "db3", "sym5"])
def test_example_fs2d(shape, wavelet) -> None:
def test_example_fs2d(shape: Sequence[int], wavelet: str) -> None:
"""Test 2d fully separable padding."""
data = torch.randn(*shape).type(torch.float64)
coeff = fswavedec2(data, wavelet, level=2)
Expand All @@ -72,7 +74,7 @@ def test_example_fs2d(shape, wavelet) -> None:

@pytest.mark.parametrize("shape", [(5, 64, 64, 64), (5, 65, 65, 65), (5, 29, 29, 29)])
@pytest.mark.parametrize("wavelet", ["haar", "db3", "sym5"])
def test_example_fs3d(shape, wavelet) -> None:
def test_example_fs3d(shape: Sequence[int], wavelet: str) -> None:
"""Test 3d fully separable padding."""
data = torch.randn(*shape).type(torch.float64)
coeff = fswavedec3(data, wavelet, level=2)
Expand All @@ -87,7 +89,9 @@ def test_example_fs3d(shape, wavelet) -> None:
"shape", [[1, 64, 128, 128], [1, 3, 64, 64, 64], [2, 1, 64, 64, 64]]
)
@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (2, 3), (3, 2)])
def test_conv_mm_2d(level, shape, axes) -> None:
def test_conv_mm_2d(
level: Optional[int], shape: Sequence[int], axes: Tuple[int, int]
) -> None:
"""Compare mm and conv fully separable results."""
data = torch.randn(*shape).type(torch.float64)
fs_conv_coeff = fswavedec2(data, "haar", level=level, axes=axes)
Expand All @@ -114,7 +118,9 @@ def test_conv_mm_2d(level, shape, axes) -> None:
@pytest.mark.parametrize("level", [1, 2, 3, None])
@pytest.mark.parametrize("axes", [(-3, -2, -1), (-1, -2, -3), (2, 3, 1)])
@pytest.mark.parametrize("shape", [(5, 64, 128, 256)])
def test_conv_mm_3d(level, axes, shape) -> None:
def test_conv_mm_3d(
level: Optional[int], axes: Tuple[int, int, int], shape: Tuple[int, ...]
) -> None:
"""Compare mm and conv 3d fully separable results."""
data = torch.randn(*shape).type(torch.float64)
fs_conv_coeff = fswavedec3(data, "haar", level=level, axes=axes)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

class _MyHaarFilterBank:
@property
def filter_bank(self) -> Tuple[list, list, list, list]:
def filter_bank(self) -> Tuple[List[float], List[float], List[float], List[float]]:
"""Unscaled Haar wavelet filters."""
return (
[1 / 2, 1 / 2.0],
Expand Down

0 comments on commit 5f740e6

Please sign in to comment.