Skip to content

Commit

Permalink
Merge pull request #291 from grlee77/idwtn_none
Browse files Browse the repository at this point in the history
BUG: idwtn should allow coefficients to be set as None
  • Loading branch information
rgommers committed Mar 9, 2017
2 parents 3d0ab16 + 0367caf commit 4fa4b45
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
15 changes: 8 additions & 7 deletions pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
----------
coeffs : tuple
(cA, (cH, cV, cD)) A tuple with approximation coefficients and three
details coefficients 2D arrays like from `dwt2()`
details coefficients 2D arrays like from `dwt2()`. If any of these
components are set to ``None``, it will be treated as zeros.
wavelet : Wavelet object or name string, or 2-tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
Expand Down Expand Up @@ -113,10 +114,6 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
raise ValueError("Expected 2 axes")

coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}

# drop the keys corresponding to value = None
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)

return idwtn(coeffs, wavelet, mode, axes)


Expand Down Expand Up @@ -224,8 +221,8 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
Parameters
----------
coeffs: dict
Dictionary as in output of `dwtn`. Missing or None items
will be treated as zeroes.
Dictionary as in output of ``dwtn``. Missing or ``None`` items
will be treated as zeros.
wavelet : Wavelet object or name string, or tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
Expand All @@ -247,6 +244,10 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
Original signal reconstructed from input data.
"""

# drop the keys corresponding to value = None
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)

# Raise error for invalid key combinations
coeffs = _fix_coeffs(coeffs)

Expand Down
38 changes: 34 additions & 4 deletions pywt/tests/test_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,6 @@ def test_error_on_invalid_keys():
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)

# a key whose value is None
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': None}
assert_raises(ValueError, pywt.idwtn, d, wavelet)

# mismatched key lengths
d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)
Expand Down Expand Up @@ -268,6 +264,40 @@ def test_idwtn_axes():
assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)


def test_idwt2_none_coeffs():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1))

# verify setting coefficients to None is the same as zeroing them
cD = np.zeros_like(cD)
result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))

cD = None
result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))

assert_equal(result_zeros, result_none)


def test_idwtn_none_coeffs():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))

# verify setting coefficients to None is the same as zeroing them
coefs['dd'] = np.zeros_like(coefs['dd'])
result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1))

coefs['dd'] = None
result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1))

assert_equal(result_zeros, result_none)


def test_idwt2_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
Expand Down

0 comments on commit 4fa4b45

Please sign in to comment.