Skip to content

Commit

Permalink
Fix reduce_shape + Add tests to operation_utils_test.py (#826)
Browse files Browse the repository at this point in the history
* Increase-tests-in-ops/operation_utils_test.py

* Increase-tests-in-ops/operation_utils_test.py

* Fix: the shape reduction logic in `reduce_shape`

* fix reduce_shape Function + add Tests

* fix reduce_shape + add tests operation_utils
  • Loading branch information
Faisal-Alsrheed authored Aug 31, 2023
1 parent a0e12f7 commit 8aa504f
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 8 deletions.
14 changes: 6 additions & 8 deletions keras_core/ops/operation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,20 +222,18 @@ def reduce_shape(shape, axis=None, keepdims=False):
shape = list(shape)
if axis is None:
if keepdims:
output_shape = [1 for _ in range(shape)]
return tuple([1 for _ in shape])
else:
output_shape = []
return output_shape
return tuple([])

if keepdims:
for ax in axis:
shape[ax] = 1
return shape
return tuple(shape)
else:
for ax in axis:
shape[ax] = -1
output_shape = list(filter((-1).__ne__, shape))
return output_shape
for ax in sorted(axis, reverse=True):
del shape[ax]
return tuple(shape)


@keras_core_export("keras_core.utils.get_source_inputs")
Expand Down
167 changes: 167 additions & 0 deletions keras_core/ops/operation_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,170 @@ def test_get_source_inputs(self):
def test_get_source_inputs_return_input_tensor(self):
inputs = input_layer.Input(shape=(10,))
self.assertIs(operation_utils.get_source_inputs(inputs)[0], inputs)

def test_compute_pooling_output_shape(self):
input_shape = (1, 4, 4, 1)
pool_size = (2, 2)
strides = (2, 2)
output_shape = operation_utils.compute_pooling_output_shape(
input_shape, pool_size, strides
)
expected_output_shape = (1, 2, 2, 1)
self.assertEqual(output_shape, expected_output_shape)

def test_compute_pooling_output_shape_with_none(self):
input_shape = (None, 4, 4, 1)
pool_size = (2, 2)
strides = (2, 2)
output_shape = operation_utils.compute_pooling_output_shape(
input_shape, pool_size, strides
)
expected_output_shape = (None, 2, 2, 1)
self.assertEqual(output_shape, expected_output_shape)

def test_compute_pooling_output_shape_valid_padding(self):
input_shape = (1, 4, 4, 1)
pool_size = (2, 2)
strides = (2, 2)
output_shape = operation_utils.compute_pooling_output_shape(
input_shape, pool_size, strides, padding="valid"
)
self.assertEqual(output_shape, (1, 2, 2, 1))

def test_compute_pooling_output_shape_channels_last(self):
input_shape = (1, 4, 4, 3)
pool_size = (2, 2)
strides = (2, 2)
output_shape = operation_utils.compute_pooling_output_shape(
input_shape,
pool_size,
strides,
padding="valid",
data_format="channels_last",
)
self.assertEqual(output_shape, (1, 2, 2, 3))

def test_compute_pooling_output_shape_same_padding_stride1(self):
input_shape = (1, 4, 4, 3)
pool_size = (2, 2)
strides = (1, 1)
output_shape = operation_utils.compute_pooling_output_shape(
input_shape,
pool_size,
strides,
padding="same",
data_format="channels_last",
)
self.assertEqual(output_shape, (1, 4, 4, 3))

def test_compute_conv_output_shape(self):
input_shape = (1, 4, 4, 1)
filters = 1
kernel_size = (3, 3)
strides = (1, 1)
output_shape = operation_utils.compute_conv_output_shape(
input_shape, filters, kernel_size, strides
)
expected_output_shape = (1, 2, 2, 1)
self.assertEqual(output_shape, expected_output_shape)

def test_compute_conv_output_shape_with_none(self):
input_shape = (None, 4, 4, 1)
kernel_size = (3, 3)
filters = 1
strides = (1, 1)
output_shape = operation_utils.compute_conv_output_shape(
input_shape, filters, kernel_size, strides
)
expected_output_shape = (None, 2, 2, 1)
self.assertEqual(output_shape, expected_output_shape)

def test_compute_conv_output_shape_valid_padding(self):
input_shape = (1, 4, 4, 1)
kernel_size = (3, 3)
filters = 1
strides = (2, 2)
output_shape = operation_utils.compute_conv_output_shape(
input_shape, filters, kernel_size, strides, padding="valid"
)
self.assertEqual(output_shape, (1, 1, 1, 1))

def test_compute_conv_output_shape_channels_last(self):
input_shape = (1, 4, 4, 3)
kernel_size = (3, 3)
filters = 3
strides = (2, 2)
output_shape = operation_utils.compute_conv_output_shape(
input_shape,
filters,
kernel_size,
strides,
padding="valid",
data_format="channels_last",
)
self.assertEqual(output_shape, (1, 1, 1, 3))

def test_compute_conv_output_shape_same_padding_stride1(self):
input_shape = (1, 4, 4, 3)
kernel_size = (3, 3)
filters = 3
strides = (1, 1)
output_shape = operation_utils.compute_conv_output_shape(
input_shape,
filters,
kernel_size,
strides,
padding="same",
data_format="channels_last",
)
self.assertEqual(output_shape, (1, 4, 4, 3))

def test_compute_reshape_output_shape(self):
input_shape = (1, 4, 4, 1)
target_shape = (16, 1)
output_shape = operation_utils.compute_reshape_output_shape(
input_shape, new_shape=target_shape, new_shape_arg_name="New shape"
)
self.assertEqual(output_shape, target_shape)

def test_reduce_shape_no_axes_no_keepdims(self):
input_shape = (1, 4, 4, 1)
output_shape = operation_utils.reduce_shape(input_shape)
expected_output_shape = ()
self.assertEqual(output_shape, expected_output_shape)

def test_reduce_shape_no_axes_with_keepdims(self):
input_shape = (1, 4, 4, 1)
output_shape = operation_utils.reduce_shape(input_shape, keepdims=True)
expected_output_shape = (1, 1, 1, 1)
self.assertEqual(output_shape, expected_output_shape)

def test_reduce_shape_single_axis_no_keepdims(self):
input_shape = (1, 4, 4, 1)
axes = [1]
output_shape = operation_utils.reduce_shape(input_shape, axes)
expected_output_shape = (1, 4, 1)
self.assertEqual(output_shape, expected_output_shape)

def test_reduce_shape_single_axis_with_keepdims(self):
input_shape = (1, 4, 4, 1)
axes = [1]
output_shape = operation_utils.reduce_shape(
input_shape, axes, keepdims=True
)
expected_output_shape = (1, 1, 4, 1)
self.assertEqual(output_shape, expected_output_shape)

def test_reduce_shape_multiple_axes_no_keepdims(self):
input_shape = (1, 4, 4, 1)
axes = [1, 2]
output_shape = operation_utils.reduce_shape(input_shape, axes)
expected_output_shape = (1, 1)
self.assertEqual(output_shape, expected_output_shape)

def test_reduce_shape_out_of_order_axes_no_keepdims(self):
input_shape = (1, 4, 4, 1)
axes = [2, 1]
output_shape = operation_utils.reduce_shape(input_shape, axes)
expected_output_shape = (1, 1)
self.assertEqual(output_shape, expected_output_shape)

0 comments on commit 8aa504f

Please sign in to comment.