Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
felixblanke committed Jul 2, 2024
1 parent f1c6f55 commit 9d5a779
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
5 changes: 2 additions & 3 deletions src/ptwt/matmul_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
wavelet: Union[Wavelet, str],
level: Optional[int] = None,
*,
axis: Optional[int] = -1,
axis: int = -1,
orthogonalization: OrthogonalizeMethod = "qr",
odd_coeff_padding_mode: BoundaryMode = "zero",
) -> None:
Expand All @@ -223,8 +223,7 @@ def __init__(
level (int, optional): The level up to which to compute the fwt. If None,
the maximum level based on the signal length is chosen. Defaults to
None.
axis (int, optional): The axis we would like to transform.
Defaults to -1.
axis (int): The axis we would like to transform. Defaults to -1.
orthogonalization: The method used to orthogonalize
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
Defaults to 'qr'.
Expand Down
6 changes: 3 additions & 3 deletions src/ptwt/matmul_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _split_rec(
_split_rec(lll, "", 3, coeff_dict)
lll = coeff_dict["aaa"]
result_keys = list(
filter(lambda x: len(x) == 3 and not x == "aaa", coeff_dict.keys())
filter(lambda x: len(x) == 3 and x != "aaa", coeff_dict.keys())
)
coeff_dict = {
key: tensor for key, tensor in coeff_dict.items() if key in result_keys
Expand Down Expand Up @@ -394,7 +394,7 @@ def _construct_synthesis_matrices(

def _cat_coeff_recursive(self, input_dict: WaveletDetailDict) -> torch.Tensor:
done_dict = {}
a_initial_keys = list(filter(lambda x: x[0] == "a", input_dict.keys()))
a_initial_keys = filter(lambda x: x[0] == "a", input_dict.keys())
for a_key in a_initial_keys:
d_key = "d" + a_key[1:]
cat_d = input_dict[d_key]
Expand Down Expand Up @@ -470,7 +470,7 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor:
"All coefficients on each level must have the same shape"
)

coeff_dict["a" * len(list(coeff_dict.keys())[-1])] = lll
coeff_dict["aaa"] = lll
lll = self._cat_coeff_recursive(coeff_dict)

for dim, mat in enumerate(self.ifwt_matrix_list[level - 1 - c_pos][::-1]):
Expand Down

0 comments on commit 9d5a779

Please sign in to comment.