diff --git a/pywt/_multidim.py b/pywt/_multidim.py index c01a58fe..1cc54b23 100644 --- a/pywt/_multidim.py +++ b/pywt/_multidim.py @@ -83,7 +83,8 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): ---------- coeffs : tuple (cA, (cH, cV, cD)) A tuple with approximation coefficients and three - details coefficients 2D arrays like from `dwt2()` + details coefficients 2D arrays like from `dwt2()`. If any of these + components are set to ``None``, it will be treated as zeros. wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to apply along each axis in ``axes``. @@ -113,10 +114,6 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): raise ValueError("Expected 2 axes") coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH} - - # drop the keys corresponding to value = None - coeffs = dict((k, v) for k, v in coeffs.items() if v is not None) - return idwtn(coeffs, wavelet, mode, axes) @@ -224,8 +221,8 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None): Parameters ---------- coeffs: dict - Dictionary as in output of `dwtn`. Missing or None items - will be treated as zeroes. + Dictionary as in output of ``dwtn``. Missing or ``None`` items + will be treated as zeros. wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to apply along each axis in ``axes``. @@ -247,6 +244,10 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None): Original signal reconstructed from input data. """ + + # drop the keys corresponding to value = None + coeffs = dict((k, v) for k, v in coeffs.items() if v is not None) + # Raise error for invalid key combinations coeffs = _fix_coeffs(coeffs) diff --git a/pywt/tests/test_multidim.py b/pywt/tests/test_multidim.py index 6ddcd9bb..c9668ab0 100644 --- a/pywt/tests/test_multidim.py +++ b/pywt/tests/test_multidim.py @@ -196,10 +196,6 @@ def test_error_on_invalid_keys(): d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH} assert_raises(ValueError, pywt.idwtn, d, wavelet) - # a key whose value is None - d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': None} - assert_raises(ValueError, pywt.idwtn, d, wavelet) - # mismatched key lengths d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH} assert_raises(ValueError, pywt.idwtn, d, wavelet) @@ -268,6 +264,40 @@ def test_idwtn_axes(): assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14) +def test_idwt2_none_coeffs(): + data = np.array([[0, 1, 2, 3], + [1, 1, 1, 1], + [1, 4, 2, 8]]) + data = data + 1j*data # test with complex data + cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1)) + + # verify setting coefficients to None is the same as zeroing them + cD = np.zeros_like(cD) + result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1)) + + cD = None + result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1)) + + assert_equal(result_zeros, result_none) + + +def test_idwtn_none_coeffs(): + data = np.array([[0, 1, 2, 3], + [1, 1, 1, 1], + [1, 4, 2, 8]]) + data = data + 1j*data # test with complex data + coefs = pywt.dwtn(data, 'haar', axes=(1, 1)) + + # verify setting coefficients to None is the same as zeroing them + coefs['dd'] = np.zeros_like(coefs['dd']) + result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1)) + + coefs['dd'] = None + result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1)) + + assert_equal(result_zeros, result_none) + + def test_idwt2_axes(): data = np.array([[0, 1, 2, 3], [1, 1, 1, 1],