From 620b952884aae0d8026514204f6c1ae6ff54b081 Mon Sep 17 00:00:00 2001
From: Calvin McCarter <77687912+calvinmccarter-at-lightmatter@users.noreply.github.com>
Date: Thu, 13 May 2021 11:35:05 -0500
Subject: [PATCH] Explicit padding for const shape ConvTranspose (#1513)

* explicit padding for const shape ConvTranspose

Signed-off-by: Calvin McCarter

* pylint fix

Signed-off-by: Calvin McCarter

* test_backend graph_validator fix

Signed-off-by: Calvin McCarter

* keep output_shape if input_dims unknown

Signed-off-by: Calvin McCarter

* cleanup

Signed-off-by: Calvin McCarter

* remove unused variable

Signed-off-by: Calvin McCarter

Co-authored-by: TomWildenhain-Microsoft <67606533+TomWildenhain-Microsoft@users.noreply.github.com>
Co-authored-by: Guenther Schmuelling
---
 tests/test_backend.py    |  2 +-
 tf2onnx/onnx_opset/nn.py | 33 +++++++++++++++++++++++++++++----
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/tests/test_backend.py b/tests/test_backend.py
index b33760b95..4fa7998c1 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -4706,7 +4706,7 @@ def func(filter_val, out_backprop_val, batch_dim):
         def graph_validator(g):
             for n in g.get_nodes():
                 if n.type == 'ConvTranspose':
-                    return "output_shape" in n.attr
+                    return "pads" in n.attr or "output_shape" in n.attr
             return False
         self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val, _INPUT2: batch_dim_val},
                             graph_validator=graph_validator)
diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py
index 3026e866a..36cc09211 100644
--- a/tf2onnx/onnx_opset/nn.py
+++ b/tf2onnx/onnx_opset/nn.py
@@ -437,10 +437,9 @@ def version_1(cls, ctx, node, **kwargs):
         input_dims = input_shape[2:2+spatial]
         output_shape_orig = node.output_shapes
 
-        # ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
+        # output_shape is explicitly specified here and then converted to explicit pads.
         output_shape = get_shape_from_const_or_concat(ctx, node.inputs[0])
         if output_shape is not None:
-            #output_shape = ctx.get_shape(node.output[0])
             if is_channels_last(node):
                 new_output_shape = [output_shape[1], output_shape[2]]
                 if spatial == 3:
@@ -453,8 +452,34 @@ def version_1(cls, ctx, node, **kwargs):
             utils.make_sure(new_output_shape.count(-1) <= 0, "output dims need to be known")
             utils.make_sure(all(new_output_shape[i] >= input_dims[i] for i in range(spatial)),
                             "output dims cannot be smaller than input dims.")
-
-            node.set_attr("output_shape", new_output_shape)
+            if -1 in input_dims:
+                node.set_attr("output_shape", new_output_shape)
+            else:
+                if "strides" in node.attr:
+                    strides = parse_dims_attr(node, node.get_attr("strides").ints, spatial)
+                else:
+                    strides = [1] * spatial
+                if "dilations" in node.attr:
+                    dilations = parse_dims_attr(node, node.get_attr("dilations").ints, spatial)
+                else:
+                    dilations = [1] * spatial
+                if "output_padding" in node.attr:
+                    output_padding = parse_dims_attr(node, node.get_attr("output_padding").ints, spatial)
+                else:
+                    output_padding = [0] * spatial
+                kernel_shape = parse_dims_attr(node, node.get_attr("kernel_shape").ints, spatial)
+                total_padding = [-1] * spatial
+                pads = [1] * (spatial * 2)
+                for i in range(spatial):
+                    total_padding[i] = (strides[i] * (input_dims[i] - 1) + output_padding[i]
+                                        + ((kernel_shape[i] - 1) * dilations[i] + 1)
+                                        - new_output_shape[i])
+                    start_i = i
+                    end_i = i + spatial
+                    pads[start_i] = int(total_padding[i] / 2)
+                    pads[end_i] = total_padding[i] - pads[start_i]
+                node.set_attr("pads", pads)
+                node.set_attr("auto_pad", "NOTSET")
         else:
             utils.make_sure(ctx.opset >= 10, "Opset 10 needed for Conv Backprop Input with non-constant shape")
             strides = parse_dims_attr(node, node.get_attr('strides').ints, spatial)
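Note (not part of the patch): the loop added in nn.py inverts the ONNX ConvTranspose shape relation, out = stride * (in - 1) + output_padding + ((kernel - 1) * dilation + 1) - pad_begin - pad_end, solving for the total padding and splitting it between the begin and end of each spatial axis. A minimal standalone sketch of the same conversion follows; the function name and signature are illustrative, not tf2onnx API, and it assumes all spatial input dims are known (the patch's non-"output_shape" branch).

def output_shape_to_pads(input_dims, output_dims, kernel_shape,
                         strides=None, dilations=None, output_padding=None):
    spatial = len(input_dims)
    strides = strides or [1] * spatial
    dilations = dilations or [1] * spatial
    output_padding = output_padding or [0] * spatial
    # ONNX pads layout: [x1_begin, ..., xn_begin, x1_end, ..., xn_end]
    pads = [0] * (spatial * 2)
    for i in range(spatial):
        # Total padding to trim so the op yields output_dims[i].
        effective_kernel = (kernel_shape[i] - 1) * dilations[i] + 1
        total = (strides[i] * (input_dims[i] - 1) + output_padding[i]
                 + effective_kernel - output_dims[i])
        pads[i] = total // 2                 # begin side gets the smaller half
        pads[i + spatial] = total - pads[i]  # end side gets the remainder
    return pads

# Example with one spatial dim: input 4, stride 2, kernel 3 gives an unpadded
# output of 2*(4-1) + 3 = 9; requesting output 8 leaves total padding 1,
# split as begin=0, end=1.
assert output_shape_to_pads([4], [8], [3], strides=[2]) == [0, 1]

The patch writes int(total_padding[i] / 2) where the sketch uses total // 2; the two agree for the non-negative totals guaranteed by the make_sure check that output dims are at least the input dims.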