From 128e6f081019adee80d5355a7918d6da7b8742fe Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Fri, 16 Apr 2021 01:17:16 -0400 Subject: [PATCH] Add support for Explicit padding Signed-off-by: Tom Wildenhain --- tests/test_backend.py | 10 ++++++++++ tf2onnx/onnx_opset/nn.py | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index 04fd1626b..83cc8d104 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -438,6 +438,16 @@ def test_conv2d_6(self): kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape) self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05) + @check_tf_min_version("1.14", "tf 1.14 needed for explicit padding") + def test_conv2d_explicit_padding(self): + x_shape = [1, 35, 35, 288] + kernel_shape = [3, 3, 288, 384] + pads = [[0, 0], [1, 2], [3, 4], [0, 0]] + strides = [1, 1, 1, 1] + x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape) + kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape) + self._conv_test(x_val, kernel_val, strides=strides, padding=pads, rtol=1e-05) + def test_conv2d_dilation_same(self): x_shape = [1, 35, 35, 288] # NHWC kernel_shape = [3, 3, 288, 384] # [filter_height, filter_width, in_channels, out_channels] diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 491fada6d..28e6883b2 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -249,6 +249,15 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2): node.set_attr("pads", pads) elif padding == "VALID": pass + elif padding == "EXPLICIT": + pads = node.get_attr_value("explicit_paddings") + start_pads = [] + end_pads = [] + d = 1 if is_channels_last(node) else 2 + for i in range(spatial): + start_pads.append(pads[(d + i) * 2]) + end_pads.append(pads[(d + i) * 2 + 1]) + node.set_attr("pads", start_pads + end_pads) else: raise ValueError("invalid padding value: {}".format(padding))