From 5949a7668311f0458bae2dfb583056687a9741dd Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 5 Sep 2023 16:15:08 -0700 Subject: [PATCH] Fix shape inference bug for conv2dtranspose --- keras_core/backend/common/backend_utils.py | 3 +++ .../convolutional/conv_transpose_test.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/keras_core/backend/common/backend_utils.py b/keras_core/backend/common/backend_utils.py index e27d9b28f..ace6d9ccf 100644 --- a/keras_core/backend/common/backend_utils.py +++ b/keras_core/backend/common/backend_utils.py @@ -187,6 +187,9 @@ def compute_conv_transpose_padding_args_for_torch( def _get_output_shape_given_tf_padding( input_size, kernel_size, strides, padding, output_padding, dilation_rate ): + if input_size is None: + return None + assert padding.lower() in {"valid", "same"} kernel_size = (kernel_size - 1) * dilation_rate + 1 diff --git a/keras_core/layers/convolutional/conv_transpose_test.py b/keras_core/layers/convolutional/conv_transpose_test.py index 3178adcc2..831390bfd 100644 --- a/keras_core/layers/convolutional/conv_transpose_test.py +++ b/keras_core/layers/convolutional/conv_transpose_test.py @@ -835,3 +835,23 @@ def test_conv1d_transpose_consistency( # Compare results kc_res = kc_layer(input) self.assertAllClose(expected_res, kc_res, atol=1e-5) + + @parameterized.product( + kernel_size=list(range(1, 5)), + strides=list(range(1, 5)), + padding=["same", "valid"], + output_padding=[None] + list(range(1, 5)), + ) + def test_shape_inference_static_unknown_shape( + self, kernel_size, strides, padding, output_padding + ): + x = layers.Input(shape=(None, None, 3)) + x = layers.Conv2DTranspose( + filters=2, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=1, + )(x) + self.assertEqual(x.shape, (None, None, None, 2))