From 48dfe3195197b1f9ea0acb6458c127ed720874a6 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 23 Apr 2024 09:21:36 +0000 Subject: [PATCH 01/10] Fix and test math functions for jax backend --- keras/src/backend/jax/math.py | 9 ++ keras/src/backend/jax/math_test.py | 140 +++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 keras/src/backend/jax/math_test.py diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 361eeee89173..19a10858637b 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -17,6 +17,8 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False): "Argument `num_segments` must be set when using the JAX backend. " "Received: num_segments=None" ) + if jnp.any(segment_ids < 0) or jnp.any(segment_ids >= num_segments): + raise ValueError("Segment ID out of range") return jax.ops.segment_sum( data, segment_ids, num_segments, indices_are_sorted=sorted ) @@ -28,6 +30,8 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False): "Argument `num_segments` must be set when using the JAX backend. " "Received: num_segments=None" ) + if jnp.any(segment_ids < 0) or jnp.any(segment_ids >= num_segments): + raise ValueError("Segment ID out of range") return jax.ops.segment_max( data, segment_ids, num_segments, indices_are_sorted=sorted ) @@ -201,7 +205,12 @@ def istft( window="hann", center=True, ): + x = _get_complex_tensor_from_tuple(x) + if x.ndim != 2: + raise ValueError( + "Input `x` must be a 2D tensor. Received: x.ndim={x.ndim}" + ) dtype = jnp.real(x).dtype expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py new file mode 100644 index 000000000000..fb6a1c363c5b --- /dev/null +++ b/keras/src/backend/jax/math_test.py @@ -0,0 +1,140 @@ +import jax.numpy as jnp +import pytest + +from keras.src import backend +from keras.src import testing +from keras.src.backend.jax.math import _get_complex_tensor_from_tuple +from keras.src.backend.jax.math import istft +from keras.src.backend.jax.math import qr +from keras.src.backend.jax.math import segment_max +from keras.src.backend.jax.math import segment_sum +from keras.src.backend.jax.math import stft + + +@pytest.mark.skipif( + backend.backend() != "jax", reason="Testing Jax functions only" +) +class TestJaxMathErrors(testing.TestCase): + + def test_segment_sum_no_num_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + segment_sum(data, segment_ids) + + def test_segment_sum_invalid_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 1, 2, 3]) + num_segments = 2 + with self.assertRaisesRegex(ValueError, "Segment ID out of range"): + segment_sum(data, segment_ids, num_segments) + + def test_segment_max_no_num_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + segment_max(data, segment_ids) + + def test_segment_max_invalid_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array( + [0, 1, 2, 3] + ) # Last index is out of range for num_segments=2 + num_segments = 2 + with self.assertRaisesRegex(ValueError, "Segment ID out of range"): + segment_max(data, segment_ids, num_segments) + + def test_qr_invalid_mode(self): + x = jnp.array([[1, 2], [3, 4]]) + invalid_mode = "invalid_mode" + with self.assertRaisesRegex( + ValueError, "Expected one of {'reduced', 'complete'}." + ): + qr(x, mode=invalid_mode) + + def test_get_complex_tensor_from_tuple_valid_input(self): + real = jnp.array([1.0, 2.0, 3.0]) + imag = jnp.array([4.0, 5.0, 6.0]) + complex_tensor = _get_complex_tensor_from_tuple((real, imag)) + self.assertTrue(jnp.iscomplexobj(complex_tensor)) + self.assertTrue(jnp.array_equal(jnp.real(complex_tensor), real)) + self.assertTrue(jnp.array_equal(jnp.imag(complex_tensor), imag)) + + def test_invalid_get_complex_tensor_from_tuple_input_type(self): + with self.assertRaisesRegex(ValueError, "Input `x` should be a tuple"): + _get_complex_tensor_from_tuple(jnp.array([1.0, 2.0, 3.0])) + + def test_invalid_get_complex_tensor_from_tuple_input_length(self): + with self.assertRaisesRegex(ValueError, "Input `x` should be a tuple"): + _get_complex_tensor_from_tuple( + ( + jnp.array([1.0, 2.0, 3.0]), + jnp.array([4.0, 5.0, 6.0]), + jnp.array([7.0, 8.0, 9.0]), + ) + ) + + def test_mismatched_shapes(self): + real = jnp.array([1.0, 2.0, 3.0]) + imag = jnp.array([4.0, 5.0]) + with self.assertRaisesRegex(ValueError, "Both the real and imaginary"): + _get_complex_tensor_from_tuple((real, imag)) + + def test_invalid_dtype(self): + real = jnp.array([1, 2, 3]) + imag = jnp.array([4.0, 5.0, 6.0]) + with self.assertRaisesRegex(ValueError, "At least one tensor in input"): + _get_complex_tensor_from_tuple((real, imag)) + + def test_stft_invalid_input_type(self): + x = jnp.array([1, 2, 3, 4]) # Integer input + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + with self.assertRaisesRegex(TypeError, "`float32` or `float64`"): + stft(x, sequence_length, sequence_stride, fft_length) + + def test_invalid_fft_length(self): + x = jnp.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 4 + sequence_stride = 1 + fft_length = 2 + with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"): + stft(x, sequence_length, sequence_stride, fft_length) + + def test_stft_invalid_window(self): + x = jnp.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = "invalid_window" + with self.assertRaisesRegex(ValueError, "If a string is passed to"): + stft(x, sequence_length, sequence_stride, fft_length, window=window) + + def test_stft_invalid_window_shape(self): + x = jnp.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = jnp.ones((sequence_length + 1)) + with self.assertRaisesRegex(ValueError, "The shape of `window` must"): + stft(x, sequence_length, sequence_stride, fft_length, window=window) + + def test_istft_invalid_window_shape(self): + x = (jnp.array([1.0, 2.0, 3.0, 4.0]), jnp.array([0.0, 0.0, 0.0, 0.0])) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = jnp.ones((sequence_length + 1)) # Invalid window shape + with self.assertRaisesRegex( + ValueError, "Input `x` must be a 2D tensor" + ): + istft( + x, sequence_length, sequence_stride, fft_length, window=window + ) From 05d4f9cb2dccea37cb1a3af698ff5a1c84b54b48 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 23 Apr 2024 09:32:18 +0000 Subject: [PATCH 02/10] run /workspaces/keras/shell/format.sh --- keras/api/_tf_keras/keras/losses/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/api/_tf_keras/keras/losses/__init__.py b/keras/api/_tf_keras/keras/losses/__init__.py index 9a134077e032..be173a43c4d5 100644 --- a/keras/api/_tf_keras/keras/losses/__init__.py +++ b/keras/api/_tf_keras/keras/losses/__init__.py @@ -8,9 +8,9 @@ from keras.src.losses import get from keras.src.losses import serialize from keras.src.losses.loss import Loss +from keras.src.losses.losses import CTC from keras.src.losses.losses import BinaryCrossentropy from keras.src.losses.losses import BinaryFocalCrossentropy -from keras.src.losses.losses import CTC from keras.src.losses.losses import CategoricalCrossentropy from keras.src.losses.losses import CategoricalFocalCrossentropy from keras.src.losses.losses import CategoricalHinge From a85f36806df5d43ef1c64e7922aa023366cf929b Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 23 Apr 2024 12:57:31 +0000 Subject: [PATCH 03/10] refix --- keras/src/backend/jax/math.py | 8 -------- keras/src/backend/jax/math_test.py | 16 ---------------- 2 files changed, 24 deletions(-) diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 19a10858637b..d9ac5fc8a951 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -17,8 +17,6 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False): "Argument `num_segments` must be set when using the JAX backend. " "Received: num_segments=None" ) - if jnp.any(segment_ids < 0) or jnp.any(segment_ids >= num_segments): - raise ValueError("Segment ID out of range") return jax.ops.segment_sum( data, segment_ids, num_segments, indices_are_sorted=sorted ) @@ -30,8 +28,6 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False): "Argument `num_segments` must be set when using the JAX backend. " "Received: num_segments=None" ) - if jnp.any(segment_ids < 0) or jnp.any(segment_ids >= num_segments): - raise ValueError("Segment ID out of range") return jax.ops.segment_max( data, segment_ids, num_segments, indices_are_sorted=sorted ) @@ -207,10 +203,6 @@ def istft( ): x = _get_complex_tensor_from_tuple(x) - if x.ndim != 2: - raise ValueError( - "Input `x` must be a 2D tensor. Received: x.ndim={x.ndim}" - ) dtype = jnp.real(x).dtype expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py index fb6a1c363c5b..458217800310 100644 --- a/keras/src/backend/jax/math_test.py +++ b/keras/src/backend/jax/math_test.py @@ -25,13 +25,6 @@ def test_segment_sum_no_num_segments(self): ): segment_sum(data, segment_ids) - def test_segment_sum_invalid_segments(self): - data = jnp.array([1, 2, 3, 4]) - segment_ids = jnp.array([0, 1, 2, 3]) - num_segments = 2 - with self.assertRaisesRegex(ValueError, "Segment ID out of range"): - segment_sum(data, segment_ids, num_segments) - def test_segment_max_no_num_segments(self): data = jnp.array([1, 2, 3, 4]) segment_ids = jnp.array([0, 0, 1, 1]) @@ -41,15 +34,6 @@ def test_segment_max_no_num_segments(self): ): segment_max(data, segment_ids) - def test_segment_max_invalid_segments(self): - data = jnp.array([1, 2, 3, 4]) - segment_ids = jnp.array( - [0, 1, 2, 3] - ) # Last index is out of range for num_segments=2 - num_segments = 2 - with self.assertRaisesRegex(ValueError, "Segment ID out of range"): - segment_max(data, segment_ids, num_segments) - def test_qr_invalid_mode(self): x = jnp.array([[1, 2], [3, 4]]) invalid_mode = "invalid_mode" From 19787912c0c064f806a77230107d3e58deaf71c2 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 23 Apr 2024 13:31:59 +0000 Subject: [PATCH 04/10] fix --- keras/src/backend/jax/math.py | 1 - keras/src/backend/jax/math_test.py | 13 ------------- 2 files changed, 14 deletions(-) diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index d9ac5fc8a951..361eeee89173 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -201,7 +201,6 @@ def istft( window="hann", center=True, ): - x = _get_complex_tensor_from_tuple(x) dtype = jnp.real(x).dtype diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py index 458217800310..cbc7466af851 100644 --- a/keras/src/backend/jax/math_test.py +++ b/keras/src/backend/jax/math_test.py @@ -109,16 +109,3 @@ def test_stft_invalid_window_shape(self): window = jnp.ones((sequence_length + 1)) with self.assertRaisesRegex(ValueError, "The shape of `window` must"): stft(x, sequence_length, sequence_stride, fft_length, window=window) - - def test_istft_invalid_window_shape(self): - x = (jnp.array([1.0, 2.0, 3.0, 4.0]), jnp.array([0.0, 0.0, 0.0, 0.0])) - sequence_length = 2 - sequence_stride = 1 - fft_length = 4 - window = jnp.ones((sequence_length + 1)) # Invalid window shape - with self.assertRaisesRegex( - ValueError, "Input `x` must be a 2D tensor" - ): - istft( - x, sequence_length, sequence_stride, fft_length, window=window - ) From 61076421c5aadb851c6ca8d9d1b876b80395c2dc Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 23 Apr 2024 14:39:13 +0000 Subject: [PATCH 05/10] fix _get_complex_tensor_from_tuple --- keras/src/backend/jax/math.py | 6 ++++++ keras/src/backend/jax/math_test.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 361eeee89173..a00112099854 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -95,6 +95,12 @@ def _get_complex_tensor_from_tuple(x): "Both the real and imaginary parts should have the same shape. " f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" ) + # Check minimum dimension requirement (if there's such a requirement) + if len(real.shape) < 2: + raise ValueError( + "Input tensors (real and imaginary parts) must have at least two dimensions. " + f"Received real shape {real.shape}, imag shape {imag.shape}." + ) # Ensure dtype is float. if not jnp.issubdtype(real.dtype, jnp.floating) or not jnp.issubdtype( imag.dtype, jnp.floating diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py index cbc7466af851..102782c08669 100644 --- a/keras/src/backend/jax/math_test.py +++ b/keras/src/backend/jax/math_test.py @@ -9,6 +9,7 @@ from keras.src.backend.jax.math import segment_max from keras.src.backend.jax.math import segment_sum from keras.src.backend.jax.math import stft +from keras.src.backend.jax.math import istft @pytest.mark.skipif( @@ -43,8 +44,8 @@ def test_qr_invalid_mode(self): qr(x, mode=invalid_mode) def test_get_complex_tensor_from_tuple_valid_input(self): - real = jnp.array([1.0, 2.0, 3.0]) - imag = jnp.array([4.0, 5.0, 6.0]) + real = jnp.array([[1.0, 2.0, 3.0]]) + imag = jnp.array([[4.0, 5.0, 6.0]]) complex_tensor = _get_complex_tensor_from_tuple((real, imag)) self.assertTrue(jnp.iscomplexobj(complex_tensor)) self.assertTrue(jnp.array_equal(jnp.real(complex_tensor), real)) @@ -109,3 +110,28 @@ def test_stft_invalid_window_shape(self): window = jnp.ones((sequence_length + 1)) with self.assertRaisesRegex(ValueError, "The shape of `window` must"): stft(x, sequence_length, sequence_stride, fft_length, window=window) + + def test_istft_invalid_window_shape(self): + x = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + incorrect_window = jnp.ones((sequence_length + 1,)) + with self.assertRaisesRegex(ValueError, "The shape of `window` must be equal to \[sequence_length\]."): + istft(x, sequence_length, sequence_stride, fft_length, window=incorrect_window) + + def test_invalid_dtype(self): + real = jnp.array([[1, 2, 3]]) + imag = jnp.array([[4.0, 5.0, 6.0]]) + expected_message = "is not of type float" + with self.assertRaisesRegex(ValueError, expected_message): + _get_complex_tensor_from_tuple((real, imag)) + + def test_istft_invalid_window_shape(self): + x = (jnp.array([[1.0, 2.0]]), jnp.array([[3.0, 4.0]])) # Now two-dimensional + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + incorrect_window = jnp.ones((sequence_length + 1,)) + with self.assertRaisesRegex(ValueError, "The shape of `window` must be equal to \[sequence_length\]."): + istft(x, sequence_length, sequence_stride, fft_length, window=incorrect_window) From 35c5da7dd0350edc35aba0c9524a8548ae6e14fc Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 23 Apr 2024 15:33:33 +0000 Subject: [PATCH 06/10] fix --- keras/src/backend/jax/math.py | 5 ++-- keras/src/backend/jax/math_test.py | 40 +++++++++++++----------------- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index a00112099854..195e51036b16 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -95,10 +95,11 @@ def _get_complex_tensor_from_tuple(x): "Both the real and imaginary parts should have the same shape. " f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" ) - # Check minimum dimension requirement (if there's such a requirement) + # Check minimum dimension requirement if len(real.shape) < 2: raise ValueError( - "Input tensors (real and imaginary parts) must have at least two dimensions. " + "Input tensors (real and imaginary parts)" + "must have at least two dimensions." f"Received real shape {real.shape}, imag shape {imag.shape}." ) # Ensure dtype is float. diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py index 102782c08669..48b862643bbc 100644 --- a/keras/src/backend/jax/math_test.py +++ b/keras/src/backend/jax/math_test.py @@ -9,7 +9,6 @@ from keras.src.backend.jax.math import segment_max from keras.src.backend.jax.math import segment_sum from keras.src.backend.jax.math import stft -from keras.src.backend.jax.math import istft @pytest.mark.skipif( @@ -71,14 +70,8 @@ def test_mismatched_shapes(self): with self.assertRaisesRegex(ValueError, "Both the real and imaginary"): _get_complex_tensor_from_tuple((real, imag)) - def test_invalid_dtype(self): - real = jnp.array([1, 2, 3]) - imag = jnp.array([4.0, 5.0, 6.0]) - with self.assertRaisesRegex(ValueError, "At least one tensor in input"): - _get_complex_tensor_from_tuple((real, imag)) - def test_stft_invalid_input_type(self): - x = jnp.array([1, 2, 3, 4]) # Integer input + x = jnp.array([1, 2, 3, 4]) sequence_length = 2 sequence_stride = 1 fft_length = 4 @@ -111,27 +104,28 @@ def test_stft_invalid_window_shape(self): with self.assertRaisesRegex(ValueError, "The shape of `window` must"): stft(x, sequence_length, sequence_stride, fft_length, window=window) - def test_istft_invalid_window_shape(self): - x = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) - sequence_length = 2 - sequence_stride = 1 - fft_length = 4 - incorrect_window = jnp.ones((sequence_length + 1,)) - with self.assertRaisesRegex(ValueError, "The shape of `window` must be equal to \[sequence_length\]."): - istft(x, sequence_length, sequence_stride, fft_length, window=incorrect_window) - - def test_invalid_dtype(self): + def test_invalid_not_float_get_complex_tensor_from_tuple_dtype(self): real = jnp.array([[1, 2, 3]]) imag = jnp.array([[4.0, 5.0, 6.0]]) expected_message = "is not of type float" with self.assertRaisesRegex(ValueError, expected_message): _get_complex_tensor_from_tuple((real, imag)) - def test_istft_invalid_window_shape(self): - x = (jnp.array([[1.0, 2.0]]), jnp.array([[3.0, 4.0]])) # Now two-dimensional + def test_istft_invalid_window_shape2(self): + x = (jnp.array([[1.0, 2.0]]), jnp.array([[3.0, 4.0]])) sequence_length = 2 sequence_stride = 1 fft_length = 4 - incorrect_window = jnp.ones((sequence_length + 1,)) - with self.assertRaisesRegex(ValueError, "The shape of `window` must be equal to \[sequence_length\]."): - istft(x, sequence_length, sequence_stride, fft_length, window=incorrect_window) + incorrect_window = jnp.ones( + (sequence_length + 1,) + ) # Incorrect window length + with self.assertRaisesRegex( + ValueError, "The shape of `window` must be equal to" + ): + istft( + x, + sequence_length, + sequence_stride, + fft_length, + window=incorrect_window, + ) From 6217de8b39a42b2aac81429d27f54f8220df8f3c Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 23 Apr 2024 18:37:50 +0000 Subject: [PATCH 07/10] refix --- keras/src/backend/jax/math.py | 7 ------- keras/src/backend/jax/math_test.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 195e51036b16..361eeee89173 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -95,13 +95,6 @@ def _get_complex_tensor_from_tuple(x): "Both the real and imaginary parts should have the same shape. " f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" ) - # Check minimum dimension requirement - if len(real.shape) < 2: - raise ValueError( - "Input tensors (real and imaginary parts)" - "must have at least two dimensions." - f"Received real shape {real.shape}, imag shape {imag.shape}." - ) # Ensure dtype is float. if not jnp.issubdtype(real.dtype, jnp.floating) or not jnp.issubdtype( imag.dtype, jnp.floating diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py index 48b862643bbc..b09c22ac1ce2 100644 --- a/keras/src/backend/jax/math_test.py +++ b/keras/src/backend/jax/math_test.py @@ -118,7 +118,7 @@ def test_istft_invalid_window_shape2(self): fft_length = 4 incorrect_window = jnp.ones( (sequence_length + 1,) - ) # Incorrect window length + ) with self.assertRaisesRegex( ValueError, "The shape of `window` must be equal to" ): From 7f6e474af1f6a60f4c9452f17f2d3e2e3d307838 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Wed, 24 Apr 2024 06:49:05 +0000 Subject: [PATCH 08/10] Fix istft function to handle inputs with less than 2 dimensions --- keras/src/backend/jax/math.py | 3 ++ keras/src/backend/jax/math_test.py | 61 ++++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 361eeee89173..32e8224ae704 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -204,6 +204,9 @@ def istft( x = _get_complex_tensor_from_tuple(x) dtype = jnp.real(x).dtype + if len(x.shape) < 2: + raise ValueError("Input `x` must have at least 2 dimensions.") + expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) l_pad = (fft_length - sequence_length) // 2 r_pad = fft_length - sequence_length - l_pad diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py index b09c22ac1ce2..535d76ecb206 100644 --- a/keras/src/backend/jax/math_test.py +++ b/keras/src/backend/jax/math_test.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp import pytest @@ -64,12 +65,44 @@ def test_invalid_get_complex_tensor_from_tuple_input_length(self): ) ) - def test_mismatched_shapes(self): + def test_get_complex_tensor_from_tuple_mismatched_shapes(self): real = jnp.array([1.0, 2.0, 3.0]) imag = jnp.array([4.0, 5.0]) with self.assertRaisesRegex(ValueError, "Both the real and imaginary"): _get_complex_tensor_from_tuple((real, imag)) + def test_invalid_not_float_get_complex_tensor_from_tuple_dtype(self): + real = jnp.array([[1, 2, 3]]) + imag = jnp.array([[4.0, 5.0, 6.0]]) + expected_message = "is not of type float" + with self.assertRaisesRegex(ValueError, expected_message): + _get_complex_tensor_from_tuple((real, imag)) + + def test_get_complex_tensor_from_tuple_complex_tensor_creation(self): + real = jnp.array([1.0, 2.0]) + imag = jnp.array([3.0, 4.0]) + expected_complex = jax.lax.complex(real, imag) + result = _get_complex_tensor_from_tuple((real, imag)) + self.assertTrue( + jnp.array_equal(result, expected_complex), + msg="Complex tensor not created correctly.", + ) + + def test_get_complex_tensor_from_tuple_output_completeness(self): + real = jnp.array([1.0, 2.0]) + imag = jnp.array([3.0, 4.0]) + complex_tensor = _get_complex_tensor_from_tuple((real, imag)) + self.assertEqual( + jnp.real(complex_tensor)[0], + real[0], + msg="Real parts are not aligned.", + ) + self.assertEqual( + jnp.imag(complex_tensor)[0], + imag[0], + msg="Imaginary parts are not aligned.", + ) + def test_stft_invalid_input_type(self): x = jnp.array([1, 2, 3, 4]) sequence_length = 2 @@ -104,21 +137,12 @@ def test_stft_invalid_window_shape(self): with self.assertRaisesRegex(ValueError, "The shape of `window` must"): stft(x, sequence_length, sequence_stride, fft_length, window=window) - def test_invalid_not_float_get_complex_tensor_from_tuple_dtype(self): - real = jnp.array([[1, 2, 3]]) - imag = jnp.array([[4.0, 5.0, 6.0]]) - expected_message = "is not of type float" - with self.assertRaisesRegex(ValueError, expected_message): - _get_complex_tensor_from_tuple((real, imag)) - - def test_istft_invalid_window_shape2(self): + def test_istft_invalid_window_shape_2D_inputs(self): x = (jnp.array([[1.0, 2.0]]), jnp.array([[3.0, 4.0]])) sequence_length = 2 sequence_stride = 1 fft_length = 4 - incorrect_window = jnp.ones( - (sequence_length + 1,) - ) + incorrect_window = jnp.ones((sequence_length + 1,)) with self.assertRaisesRegex( ValueError, "The shape of `window` must be equal to" ): @@ -129,3 +153,16 @@ def test_istft_invalid_window_shape2(self): fft_length, window=incorrect_window, ) + + def test_istft_1D_inputs(self): + real = jnp.array([1.0, 2.0, 3.0, 4.0]) + imag = jnp.array([1.0, 2.0, 3.0, 4.0]) + x = (real, imag) + sequence_length = 3 + sequence_stride = 1 + fft_length = 4 + window = jnp.ones((sequence_length,)) + with self.assertRaisesRegex(ValueError, "Input `x` must have at least"): + istft( + x, sequence_length, sequence_stride, fft_length, window=window + ) From b5aabb6ecd2110cda688525dba99e66b44df360f Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Wed, 24 Apr 2024 07:56:38 +0000 Subject: [PATCH 09/10] fix --- keras/src/backend/jax/math_test.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py index 535d76ecb206..89ba55b9a3b5 100644 --- a/keras/src/backend/jax/math_test.py +++ b/keras/src/backend/jax/math_test.py @@ -43,13 +43,32 @@ def test_qr_invalid_mode(self): ): qr(x, mode=invalid_mode) - def test_get_complex_tensor_from_tuple_valid_input(self): + def test_get_complex_tensor_from_tuple_creates_complex_object(self): real = jnp.array([[1.0, 2.0, 3.0]]) imag = jnp.array([[4.0, 5.0, 6.0]]) complex_tensor = _get_complex_tensor_from_tuple((real, imag)) - self.assertTrue(jnp.iscomplexobj(complex_tensor)) - self.assertTrue(jnp.array_equal(jnp.real(complex_tensor), real)) - self.assertTrue(jnp.array_equal(jnp.imag(complex_tensor), imag)) + self.assertTrue( + jnp.iscomplexobj(complex_tensor), + "The output should be a complex object.", + ) + + def test_get_complex_tensor_from_tuple_correct_real_part(self): + real = jnp.array([[1.0, 2.0, 3.0]]) + imag = jnp.array([[4.0, 5.0, 6.0]]) + complex_tensor = _get_complex_tensor_from_tuple((real, imag)) + self.assertTrue( + jnp.array_equal(jnp.real(complex_tensor), real), + "The real parts should match.", + ) + + def test_get_complex_tensor_from_tuple_correct_imaginary_part(self): + real = jnp.array([[1.0, 2.0, 3.0]]) + imag = jnp.array([[4.0, 5.0, 6.0]]) + complex_tensor = _get_complex_tensor_from_tuple((real, imag)) + self.assertTrue( + jnp.array_equal(jnp.imag(complex_tensor), imag), + "The imaginary parts should match.", + ) def test_invalid_get_complex_tensor_from_tuple_input_type(self): with self.assertRaisesRegex(ValueError, "Input `x` should be a tuple"): From 997890e608ddafd5a510b19e4c9c5eb9825a0b61 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Mon, 29 Apr 2024 08:11:07 +0000 Subject: [PATCH 10/10] Fix ValueError in istft function for inputs with less than 2 dimensions --- keras/src/backend/jax/math.py | 5 +- keras/src/backend/jax/math_test.py | 187 ----------------------------- keras/src/ops/linalg_test.py | 9 ++ keras/src/ops/math_test.py | 88 ++++++++++++++ 4 files changed, 101 insertions(+), 188 deletions(-) delete mode 100644 keras/src/backend/jax/math_test.py diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 32e8224ae704..4119f744e1a3 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -205,7 +205,10 @@ def istft( dtype = jnp.real(x).dtype if len(x.shape) < 2: - raise ValueError("Input `x` must have at least 2 dimensions.") + raise ValueError( + f"Input `x` must have at least 2 dimensions. " + f"Received shape: {x.shape}" + ) expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) l_pad = (fft_length - sequence_length) // 2 diff --git a/keras/src/backend/jax/math_test.py b/keras/src/backend/jax/math_test.py deleted file mode 100644 index 89ba55b9a3b5..000000000000 --- a/keras/src/backend/jax/math_test.py +++ /dev/null @@ -1,187 +0,0 @@ -import jax -import jax.numpy as jnp -import pytest - -from keras.src import backend -from keras.src import testing -from keras.src.backend.jax.math import _get_complex_tensor_from_tuple -from keras.src.backend.jax.math import istft -from keras.src.backend.jax.math import qr -from keras.src.backend.jax.math import segment_max -from keras.src.backend.jax.math import segment_sum -from keras.src.backend.jax.math import stft - - -@pytest.mark.skipif( - backend.backend() != "jax", reason="Testing Jax functions only" -) -class TestJaxMathErrors(testing.TestCase): - - def test_segment_sum_no_num_segments(self): - data = jnp.array([1, 2, 3, 4]) - segment_ids = jnp.array([0, 0, 1, 1]) - with self.assertRaisesRegex( - ValueError, - "Argument `num_segments` must be set when using the JAX backend.", - ): - segment_sum(data, segment_ids) - - def test_segment_max_no_num_segments(self): - data = jnp.array([1, 2, 3, 4]) - segment_ids = jnp.array([0, 0, 1, 1]) - with self.assertRaisesRegex( - ValueError, - "Argument `num_segments` must be set when using the JAX backend.", - ): - segment_max(data, segment_ids) - - def test_qr_invalid_mode(self): - x = jnp.array([[1, 2], [3, 4]]) - invalid_mode = "invalid_mode" - with self.assertRaisesRegex( - ValueError, "Expected one of {'reduced', 'complete'}." - ): - qr(x, mode=invalid_mode) - - def test_get_complex_tensor_from_tuple_creates_complex_object(self): - real = jnp.array([[1.0, 2.0, 3.0]]) - imag = jnp.array([[4.0, 5.0, 6.0]]) - complex_tensor = _get_complex_tensor_from_tuple((real, imag)) - self.assertTrue( - jnp.iscomplexobj(complex_tensor), - "The output should be a complex object.", - ) - - def test_get_complex_tensor_from_tuple_correct_real_part(self): - real = jnp.array([[1.0, 2.0, 3.0]]) - imag = jnp.array([[4.0, 5.0, 6.0]]) - complex_tensor = _get_complex_tensor_from_tuple((real, imag)) - self.assertTrue( - jnp.array_equal(jnp.real(complex_tensor), real), - "The real parts should match.", - ) - - def test_get_complex_tensor_from_tuple_correct_imaginary_part(self): - real = jnp.array([[1.0, 2.0, 3.0]]) - imag = jnp.array([[4.0, 5.0, 6.0]]) - complex_tensor = _get_complex_tensor_from_tuple((real, imag)) - self.assertTrue( - jnp.array_equal(jnp.imag(complex_tensor), imag), - "The imaginary parts should match.", - ) - - def test_invalid_get_complex_tensor_from_tuple_input_type(self): - with self.assertRaisesRegex(ValueError, "Input `x` should be a tuple"): - _get_complex_tensor_from_tuple(jnp.array([1.0, 2.0, 3.0])) - - def test_invalid_get_complex_tensor_from_tuple_input_length(self): - with self.assertRaisesRegex(ValueError, "Input `x` should be a tuple"): - _get_complex_tensor_from_tuple( - ( - jnp.array([1.0, 2.0, 3.0]), - jnp.array([4.0, 5.0, 6.0]), - jnp.array([7.0, 8.0, 9.0]), - ) - ) - - def test_get_complex_tensor_from_tuple_mismatched_shapes(self): - real = jnp.array([1.0, 2.0, 3.0]) - imag = jnp.array([4.0, 5.0]) - with self.assertRaisesRegex(ValueError, "Both the real and imaginary"): - _get_complex_tensor_from_tuple((real, imag)) - - def test_invalid_not_float_get_complex_tensor_from_tuple_dtype(self): - real = jnp.array([[1, 2, 3]]) - imag = jnp.array([[4.0, 5.0, 6.0]]) - expected_message = "is not of type float" - with self.assertRaisesRegex(ValueError, expected_message): - _get_complex_tensor_from_tuple((real, imag)) - - def test_get_complex_tensor_from_tuple_complex_tensor_creation(self): - real = jnp.array([1.0, 2.0]) - imag = jnp.array([3.0, 4.0]) - expected_complex = jax.lax.complex(real, imag) - result = _get_complex_tensor_from_tuple((real, imag)) - self.assertTrue( - jnp.array_equal(result, expected_complex), - msg="Complex tensor not created correctly.", - ) - - def test_get_complex_tensor_from_tuple_output_completeness(self): - real = jnp.array([1.0, 2.0]) - imag = jnp.array([3.0, 4.0]) - complex_tensor = _get_complex_tensor_from_tuple((real, imag)) - self.assertEqual( - jnp.real(complex_tensor)[0], - real[0], - msg="Real parts are not aligned.", - ) - self.assertEqual( - jnp.imag(complex_tensor)[0], - imag[0], - msg="Imaginary parts are not aligned.", - ) - - def test_stft_invalid_input_type(self): - x = jnp.array([1, 2, 3, 4]) - sequence_length = 2 - sequence_stride = 1 - fft_length = 4 - with self.assertRaisesRegex(TypeError, "`float32` or `float64`"): - stft(x, sequence_length, sequence_stride, fft_length) - - def test_invalid_fft_length(self): - x = jnp.array([1.0, 2.0, 3.0, 4.0]) - sequence_length = 4 - sequence_stride = 1 - fft_length = 2 - with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"): - stft(x, sequence_length, sequence_stride, fft_length) - - def test_stft_invalid_window(self): - x = jnp.array([1.0, 2.0, 3.0, 4.0]) - sequence_length = 2 - sequence_stride = 1 - fft_length = 4 - window = "invalid_window" - with self.assertRaisesRegex(ValueError, "If a string is passed to"): - stft(x, sequence_length, sequence_stride, fft_length, window=window) - - def test_stft_invalid_window_shape(self): - x = jnp.array([1.0, 2.0, 3.0, 4.0]) - sequence_length = 2 - sequence_stride = 1 - fft_length = 4 - window = jnp.ones((sequence_length + 1)) - with self.assertRaisesRegex(ValueError, "The shape of `window` must"): - stft(x, sequence_length, sequence_stride, fft_length, window=window) - - def test_istft_invalid_window_shape_2D_inputs(self): - x = (jnp.array([[1.0, 2.0]]), jnp.array([[3.0, 4.0]])) - sequence_length = 2 - sequence_stride = 1 - fft_length = 4 - incorrect_window = jnp.ones((sequence_length + 1,)) - with self.assertRaisesRegex( - ValueError, "The shape of `window` must be equal to" - ): - istft( - x, - sequence_length, - sequence_stride, - fft_length, - window=incorrect_window, - ) - - def test_istft_1D_inputs(self): - real = jnp.array([1.0, 2.0, 3.0, 4.0]) - imag = jnp.array([1.0, 2.0, 3.0, 4.0]) - x = (real, imag) - sequence_length = 3 - sequence_stride = 1 - fft_length = 4 - window = jnp.ones((sequence_length,)) - with self.assertRaisesRegex(ValueError, "Input `x` must have at least"): - istft( - x, sequence_length, sequence_stride, fft_length, window=window - ) diff --git a/keras/src/ops/linalg_test.py b/keras/src/ops/linalg_test.py index e1f0decf64b0..a2fd0c61aad0 100644 --- a/keras/src/ops/linalg_test.py +++ b/keras/src/ops/linalg_test.py @@ -101,6 +101,15 @@ def test_qr(self): self.assertEqual(q.shape, qref_shape) self.assertEqual(r.shape, rref_shape) + def test_qr_invalid_mode(self): + # backend agnostic error message + x = np.array([[1, 2], [3, 4]]) + invalid_mode = "invalid_mode" + with self.assertRaisesRegex( + ValueError, "Expected one of {'reduced', 'complete'}." + ): + linalg.qr(x, mode=invalid_mode) + def test_solve(self): a = KerasTensor([None, 20, 20]) b = KerasTensor([None, 20, 5]) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index 60db9fc70f6c..86e3c70a78ee 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -1,5 +1,6 @@ import math +import jax.numpy as jnp import numpy as np import pytest import scipy.signal @@ -1256,3 +1257,90 @@ def test_undefined_fft_length_and_last_dimension(self): expected_shape = real_part.shape[:-1] + (None,) self.assertEqual(output_spec.shape, expected_shape) + + +class TestMathErrors(testing.TestCase): + + @pytest.mark.skipif( + backend.backend() != "jax", reason="Testing Jax errors only" + ) + def test_segment_sum_no_num_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + kmath.segment_sum(data, segment_ids) + + @pytest.mark.skipif( + backend.backend() != "jax", reason="Testing Jax errors only" + ) + def test_segment_max_no_num_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + kmath.segment_max(data, segment_ids) + + def test_stft_invalid_input_type(self): + # backend agnostic error message + x = np.array([1, 2, 3, 4]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + with self.assertRaisesRegex(TypeError, "`float32` or `float64`"): + kmath.stft(x, sequence_length, sequence_stride, fft_length) + + def test_invalid_fft_length(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 4 + sequence_stride = 1 + fft_length = 2 + with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"): + kmath.stft(x, sequence_length, sequence_stride, fft_length) + + def test_stft_invalid_window(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = "invalid_window" + with self.assertRaisesRegex(ValueError, "If a string is passed to"): + kmath.stft( + x, sequence_length, sequence_stride, fft_length, window=window + ) + + def test_stft_invalid_window_shape(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = np.ones((sequence_length + 1)) + with self.assertRaisesRegex(ValueError, "The shape of `window` must"): + kmath.stft( + x, sequence_length, sequence_stride, fft_length, window=window + ) + + def test_istft_invalid_window_shape_2D_inputs(self): + # backend agnostic error message + x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]])) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + incorrect_window = np.ones((sequence_length + 1,)) + with self.assertRaisesRegex( + ValueError, "The shape of `window` must be equal to" + ): + kmath.istft( + x, + sequence_length, + sequence_stride, + fft_length, + window=incorrect_window, + )