diff --git a/keras_core/ops/operation_utils.py b/keras_core/ops/operation_utils.py index 7a4d3a0ea..bfac73ec4 100644 --- a/keras_core/ops/operation_utils.py +++ b/keras_core/ops/operation_utils.py @@ -222,20 +222,18 @@ def reduce_shape(shape, axis=None, keepdims=False): shape = list(shape) if axis is None: if keepdims: - output_shape = [1 for _ in range(shape)] + return tuple([1 for _ in shape]) else: - output_shape = [] - return output_shape + return tuple([]) if keepdims: for ax in axis: shape[ax] = 1 - return shape + return tuple(shape) else: - for ax in axis: - shape[ax] = -1 - output_shape = list(filter((-1).__ne__, shape)) - return output_shape + for ax in sorted(axis, reverse=True): + del shape[ax] + return tuple(shape) @keras_core_export("keras_core.utils.get_source_inputs") diff --git a/keras_core/ops/operation_utils_test.py b/keras_core/ops/operation_utils_test.py index 06e194281..e0ec7c6fd 100644 --- a/keras_core/ops/operation_utils_test.py +++ b/keras_core/ops/operation_utils_test.py @@ -17,3 +17,170 @@ def test_get_source_inputs(self): def test_get_source_inputs_return_input_tensor(self): inputs = input_layer.Input(shape=(10,)) self.assertIs(operation_utils.get_source_inputs(inputs)[0], inputs) + + def test_compute_pooling_output_shape(self): + input_shape = (1, 4, 4, 1) + pool_size = (2, 2) + strides = (2, 2) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, pool_size, strides + ) + expected_output_shape = (1, 2, 2, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_pooling_output_shape_with_none(self): + input_shape = (None, 4, 4, 1) + pool_size = (2, 2) + strides = (2, 2) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, pool_size, strides + ) + expected_output_shape = (None, 2, 2, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_pooling_output_shape_valid_padding(self): + input_shape = (1, 4, 4, 1) + pool_size = (2, 2) + strides = (2, 2) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, pool_size, strides, padding="valid" + ) + self.assertEqual(output_shape, (1, 2, 2, 1)) + + def test_compute_pooling_output_shape_channels_last(self): + input_shape = (1, 4, 4, 3) + pool_size = (2, 2) + strides = (2, 2) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, + pool_size, + strides, + padding="valid", + data_format="channels_last", + ) + self.assertEqual(output_shape, (1, 2, 2, 3)) + + def test_compute_pooling_output_shape_same_padding_stride1(self): + input_shape = (1, 4, 4, 3) + pool_size = (2, 2) + strides = (1, 1) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, + pool_size, + strides, + padding="same", + data_format="channels_last", + ) + self.assertEqual(output_shape, (1, 4, 4, 3)) + + def test_compute_conv_output_shape(self): + input_shape = (1, 4, 4, 1) + filters = 1 + kernel_size = (3, 3) + strides = (1, 1) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, filters, kernel_size, strides + ) + expected_output_shape = (1, 2, 2, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_conv_output_shape_with_none(self): + input_shape = (None, 4, 4, 1) + kernel_size = (3, 3) + filters = 1 + strides = (1, 1) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, filters, kernel_size, strides + ) + expected_output_shape = (None, 2, 2, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_conv_output_shape_valid_padding(self): + input_shape = (1, 4, 4, 1) + kernel_size = (3, 3) + filters = 1 + strides = (2, 2) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, filters, kernel_size, strides, padding="valid" + ) + self.assertEqual(output_shape, (1, 1, 1, 1)) + + def test_compute_conv_output_shape_channels_last(self): + input_shape = (1, 4, 4, 3) + kernel_size = (3, 3) + filters = 3 + strides = (2, 2) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, + filters, + kernel_size, + strides, + padding="valid", + data_format="channels_last", + ) + self.assertEqual(output_shape, (1, 1, 1, 3)) + + def test_compute_conv_output_shape_same_padding_stride1(self): + input_shape = (1, 4, 4, 3) + kernel_size = (3, 3) + filters = 3 + strides = (1, 1) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, + filters, + kernel_size, + strides, + padding="same", + data_format="channels_last", + ) + self.assertEqual(output_shape, (1, 4, 4, 3)) + + def test_compute_reshape_output_shape(self): + input_shape = (1, 4, 4, 1) + target_shape = (16, 1) + output_shape = operation_utils.compute_reshape_output_shape( + input_shape, new_shape=target_shape, new_shape_arg_name="New shape" + ) + self.assertEqual(output_shape, target_shape) + + def test_reduce_shape_no_axes_no_keepdims(self): + input_shape = (1, 4, 4, 1) + output_shape = operation_utils.reduce_shape(input_shape) + expected_output_shape = () + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_no_axes_with_keepdims(self): + input_shape = (1, 4, 4, 1) + output_shape = operation_utils.reduce_shape(input_shape, keepdims=True) + expected_output_shape = (1, 1, 1, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_single_axis_no_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [1] + output_shape = operation_utils.reduce_shape(input_shape, axes) + expected_output_shape = (1, 4, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_single_axis_with_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [1] + output_shape = operation_utils.reduce_shape( + input_shape, axes, keepdims=True + ) + expected_output_shape = (1, 1, 4, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_multiple_axes_no_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [1, 2] + output_shape = operation_utils.reduce_shape(input_shape, axes) + expected_output_shape = (1, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_out_of_order_axes_no_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [2, 1] + output_shape = operation_utils.reduce_shape(input_shape, axes) + expected_output_shape = (1, 1) + self.assertEqual(output_shape, expected_output_shape)