Skip to content

Commit

Permalink
black.
Browse files Browse the repository at this point in the history
  • Loading branch information
v0lta committed Feb 1, 2024
1 parent 6fe0002 commit 7451514
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 26 deletions.
6 changes: 3 additions & 3 deletions src/ptwt/separable_conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def _separable_conv_waverecn(

approx: torch.Tensor = coeffs[0]
for level_dict in coeffs[1:]:
keys = list(level_dict.keys()) # type: ignore
level_dict["a" * max(map(len, keys))] = approx # type: ignore
approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore
keys = list(level_dict.keys()) # type: ignore
level_dict["a" * max(map(len, keys))] = approx # type: ignore
approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore
return approx


Expand Down
15 changes: 9 additions & 6 deletions tests/test_convolution_fwt_3.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Test our 3d for loop-convolution based fwt code."""

import typing
from typing import List, Union, Any, Dict
import numpy.typing as npt
from typing import Any, Dict, List, Union

import numpy as np
import numpy.typing as npt
import pytest
import pywt
import torch
Expand All @@ -13,17 +13,20 @@
from ptwt.constants import BoundaryMode


def _expand_dims(batch_list: List[Union[npt.NDArray[Any], Dict[Any, Any]]]
) -> List[Any]:
def _expand_dims(
batch_list: List[Union[npt.NDArray[Any], Dict[Any, Any]]]
) -> List[Any]:
for pos, bel in enumerate(batch_list):
if isinstance(bel, np.ndarray):
batch_list[pos] = np.expand_dims(bel, 0)
elif isinstance(bel, dict):
for key, item in bel.items():
batch_list[pos][key] = np.expand_dims(item, 0)
else:
raise TypeError("Argument type not supported,\
batch_list element should have been a dict.")
raise TypeError(
"Argument type not supported,\
batch_list element should have been a dict."
)
return batch_list


Expand Down
10 changes: 3 additions & 7 deletions tests/test_cwt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test the continuous transformation code."""

from typing import Union, Any
from typing import Any, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -35,9 +35,7 @@
@pytest.mark.parametrize("scales", [np.arange(1, 16), 5.0, torch.arange(1, 15)])
@pytest.mark.parametrize("samples", [31, 32])
@pytest.mark.parametrize("wavelet", continuous_wavelets)
def test_cwt(
wavelet: str, samples: int, scales: Any
) -> None:
def test_cwt(wavelet: str, samples: int, scales: Any) -> None:
"""Test the cwt implementation for various wavelets."""
t = np.linspace(-1, 1, samples, endpoint=False)
sig = signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
Expand Down Expand Up @@ -96,9 +94,7 @@ def test_nn_schannon_wavefun(type: str, grid_size: int) -> None:
@pytest.mark.parametrize("scales", [np.arange(1, 16), 5.0, torch.arange(1, 15)])
@pytest.mark.parametrize("samples", [31, 32])
@pytest.mark.parametrize("cuda", [False, True])
def test_nn_cwt(
samples: int, scales: Any, cuda: bool
) -> None:
def test_nn_cwt(samples: int, scales: Any, cuda: bool) -> None:
"""Test the cwt using a differentiable continuous wavelet."""
pywt_shannon = pywt.ContinuousWavelet("shan1-1")
ptwt_shannon = _ShannonWavelet("shan1-1")
Expand Down
9 changes: 3 additions & 6 deletions tests/test_jit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Ensure pytorch's torch.jit.trace feature works properly."""

from typing import List, NamedTuple, Optional, Union, Any, Tuple, Dict
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -121,9 +121,7 @@ def test_conv_fwt_jit_2d() -> None:
assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7)


def _to_jit_wavedec_3(
data: torch.Tensor, wavelet: str
) -> List[torch.Tensor]:
def _to_jit_wavedec_3(data: torch.Tensor, wavelet: str) -> List[torch.Tensor]:
"""Ensure uniform datatypes in lists for the tracer.
Going from List[Union[torch.Tensor, Dict[str, torch.Tensor]]] to List[torch.Tensor]
Expand All @@ -141,8 +139,7 @@ def _to_jit_wavedec_3(
return coeff2


def _to_jit_waverec_3(data: List[torch.Tensor],
wavelet: pywt.Wavelet) -> torch.Tensor:
def _to_jit_waverec_3(data: List[torch.Tensor], wavelet: pywt.Wavelet) -> torch.Tensor:
"""Undo the stacking from the jit wavedec3 wrapper."""
d_unstack: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [data[0]]
keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_matrix_fwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

# Written by moritz ( @ wolter.tech ) in 2021

from typing import List, Any
import numpy.typing as npt
from typing import Any, List

import numpy as np
import numpy.typing as npt
import pytest
import pywt
import torch
Expand Down
8 changes: 6 additions & 2 deletions tests/test_packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,12 @@ def test_boundary_matrix_packets2(
@pytest.mark.parametrize("transform_mode", [False, True])
@pytest.mark.parametrize("multiple_transforms", [False, True])
def test_1d_packets(
max_lev: int, wavelet_str: str, boundary: str, batch_size: int,
transform_mode: bool, multiple_transforms: bool
max_lev: int,
wavelet_str: str,
boundary: str,
batch_size: int,
transform_mode: bool,
multiple_transforms: bool,
) -> None:
"""Ensure pywt and ptwt produce equivalent wavelet 1d packet trees."""
_compare_trees1(
Expand Down

0 comments on commit 7451514

Please sign in to comment.