diff --git a/tests/test_matrix_fwt.py b/tests/test_matrix_fwt.py index 9d63a718..cf9e8bce 100644 --- a/tests/test_matrix_fwt.py +++ b/tests/test_matrix_fwt.py @@ -140,7 +140,9 @@ def test_boundary_transform_1d( """Ensure matrix fwt reconstructions are pywt compatible.""" data_torch = torch.from_numpy(data.astype(np.float64)) wavelet = pywt.Wavelet(wavelet_str) - matrix_wavedec = MatrixWavedec(wavelet, level=level, boundary_orthogonalization=boundary) + matrix_wavedec = MatrixWavedec( + wavelet, level=level, boundary_orthogonalization=boundary + ) coeffs = matrix_wavedec(data_torch) matrix_waverec = MatrixWaverec(wavelet, boundary_orthogonalization=boundary) rec = matrix_waverec(coeffs) @@ -174,7 +176,9 @@ def test_matrix_transform_1d_rebuild( wavelet = pywt.Wavelet(wavelet_str) matrix_waverec = MatrixWaverec(wavelet, boundary_orthogonalization=boundary) for level in [2, 1]: - matrix_wavedec = MatrixWavedec(wavelet, level=level, boundary_orthogonalization=boundary) + matrix_wavedec = MatrixWavedec( + wavelet, level=level, boundary_orthogonalization=boundary + ) for data in data_list: data_torch = torch.from_numpy(data.astype(np.float64)) coeffs = matrix_wavedec(data_torch)