From c9843abccbe64445335b98b7f81fce36fb6d3ae9 Mon Sep 17 00:00:00 2001 From: Luca Pizzini Date: Sat, 27 Apr 2024 16:31:44 +0200 Subject: [PATCH] test(trainers): add test_errors implementation for ArrayDataAdapter class --- .../data_adapters/array_data_adapter_test.py | 57 ++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/keras/src/trainers/data_adapters/array_data_adapter_test.py b/keras/src/trainers/data_adapters/array_data_adapter_test.py index 80b4462e407..a61a904240e 100644 --- a/keras/src/trainers/data_adapters/array_data_adapter_test.py +++ b/keras/src/trainers/data_adapters/array_data_adapter_test.py @@ -244,5 +244,58 @@ def test_class_weights(self, target_encoding): self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4]) def test_errors(self): - # TODO - pass + x = np.random.random((34, 1)) + y = np.random.random((34, 3)) + sw = np.random.random((34,)) + cw = { + 0: 0.1, + 1: 0.2, + 2: 0.3, + 3: 0.4, + } + + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter(x="Invalid") + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter(x=x, y="Invalid") + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=y, sample_weight="Invalid" + ) + + with self.assertRaisesRegex( + ValueError, "You cannot `class_weight` and `sample_weight`" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=y, sample_weight=sw, class_weight=cw + ) + + nested_y = ({"x": x, "y": y},) + with self.assertRaisesRegex( + ValueError, "You should provide one `sample_weight` array per" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, sample_weight=[] + ) + + tensor_sw = self.make_array("tf", (34, 2), "int32") + with self.assertRaisesRegex( + ValueError, "For a model with multiple outputs, when providing" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, sample_weight=tensor_sw + ) + + with self.assertRaisesRegex( + ValueError, + "`class_weight` is only supported for Models with a single", + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, class_weight=cw + )