Explicit padding for const shape ConvTranspose (#1513)
* explicit padding for const shape ConvTranspose

Signed-off-by: Calvin McCarter <calvin@lightmatter.co>

* pylint fix

Signed-off-by: Calvin McCarter <calvin@lightmatter.co>

* test_backend graph_validator fix

Signed-off-by: Calvin McCarter <calvin@lightmatter.co>

* keep output_shape if input_dims unknown

Signed-off-by: Calvin McCarter <calvin@lightmatter.co>

* cleanup

Signed-off-by: Calvin McCarter <calvin@lightmatter.co>

* remove unused variable

Signed-off-by: Calvin McCarter <calvin@lightmatter.co>

Co-authored-by: TomWildenhain-Microsoft <67606533+TomWildenhain-Microsoft@users.noreply.github.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
3 people authored May 13, 2021
1 parent 46cb038 commit 620b952
Showing 2 changed files with 30 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tests/test_backend.py
@@ -4706,7 +4706,7 @@ def func(filter_val, out_backprop_val, batch_dim):
         def graph_validator(g):
             for n in g.get_nodes():
                 if n.type == 'ConvTranspose':
-                    return "output_shape" in n.attr
+                    return "pads" in n.attr or "output_shape" in n.attr
             return False
         self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val, _INPUT2: batch_dim_val},
                             graph_validator=graph_validator)
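The property this validator asserts can also be checked on an exported model with the onnx package directly. Below is a minimal sketch under that assumption; the model path is hypothetical and the attribute check simply mirrors graph_validator:

# Sketch: verifying the converter's ConvTranspose attributes with the onnx
# package. The model path is hypothetical; the check mirrors graph_validator.
import onnx

model = onnx.load("convtranspose_model.onnx")  # hypothetical path
for node in model.graph.node:
    if node.op_type == "ConvTranspose":
        attrs = {a.name for a in node.attribute}
        # After this commit, a const output shape with static input dims yields
        # explicit "pads"; unknown input dims fall back to "output_shape".
        assert "pads" in attrs or "output_shape" in attrs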
33 changes: 29 additions & 4 deletions tf2onnx/onnx_opset/nn.py
@@ -437,10 +437,9 @@ def version_1(cls, ctx, node, **kwargs):
         input_dims = input_shape[2:2+spatial]
         output_shape_orig = node.output_shapes
 
-        # ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
+        # output_shape is explicitly specified here and then converted to explicit pads.
         output_shape = get_shape_from_const_or_concat(ctx, node.inputs[0])
         if output_shape is not None:
-            #output_shape = ctx.get_shape(node.output[0])
             if is_channels_last(node):
                 new_output_shape = [output_shape[1], output_shape[2]]
                 if spatial == 3:
@@ -453,8 +452,34 @@
             utils.make_sure(new_output_shape.count(-1) <= 0, "output dims need to be known")
             utils.make_sure(all(new_output_shape[i] >= input_dims[i] for i in range(spatial)),
                             "output dims cannot be smaller than input dims.")
-
-            node.set_attr("output_shape", new_output_shape)
+            if -1 in input_dims:
+                node.set_attr("output_shape", new_output_shape)
+            else:
+                if "strides" in node.attr:
+                    strides = parse_dims_attr(node, node.get_attr("strides").ints, spatial)
+                else:
+                    strides = [1] * spatial
+                if "dilations" in node.attr:
+                    dilations = parse_dims_attr(node, node.get_attr("dilations").ints, spatial)
+                else:
+                    dilations = [1] * spatial
+                if "output_padding" in node.attr:
+                    output_padding = parse_dims_attr(node, node.get_attr("output_padding").ints, spatial)
+                else:
+                    output_padding = [0] * spatial
+                kernel_shape = parse_dims_attr(node, node.get_attr("kernel_shape").ints, spatial)
+                total_padding = [-1] * spatial
+                pads = [1] * (spatial * 2)
+                for i in range(spatial):
+                    total_padding[i] = (strides[i] * (input_dims[i] - 1) + output_padding[i]
+                                        + ((kernel_shape[i] - 1) * dilations[i] + 1)
+                                        - new_output_shape[i])
+                    start_i = i
+                    end_i = i + spatial
+                    pads[start_i] = int(total_padding[i] / 2)
+                    pads[end_i] = total_padding[i] - pads[start_i]
+                node.set_attr("pads", pads)
+                node.set_attr("auto_pad", "NOTSET")
         else:
             utils.make_sure(ctx.opset >= 10, "Opset 10 needed for Conv Backprop Input with non-constant shape")
             strides = parse_dims_attr(node, node.get_attr('strides').ints, spatial)
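For reference, the padding arithmetic in the new branch can be reproduced standalone. The sketch below assumes the standard ONNX ConvTranspose output-shape relation; the helper name compute_convtranspose_pads is illustrative, not part of tf2onnx:

# Minimal sketch of the explicit-pads computation added above, assuming the
# ONNX ConvTranspose shape relation. Illustrative only, not a tf2onnx API.
def compute_convtranspose_pads(input_dims, output_shape, kernel_shape,
                               strides, dilations, output_padding):
    """Solve for explicit ONNX pads given a requested spatial output_shape.

    Per spatial axis i, ONNX ConvTranspose satisfies:
        out[i] = strides[i] * (in[i] - 1) + output_padding[i]
                 + (kernel_shape[i] - 1) * dilations[i] + 1
                 - (pads_begin[i] + pads_end[i])
    so the total padding is solved for and split begin/end, as in the diff.
    """
    spatial = len(input_dims)
    pads = [0] * (spatial * 2)
    for i in range(spatial):
        total = (strides[i] * (input_dims[i] - 1) + output_padding[i]
                 + (kernel_shape[i] - 1) * dilations[i] + 1
                 - output_shape[i])
        pads[i] = int(total / 2)              # begin pad, as in the diff
        pads[i + spatial] = total - pads[i]   # end pad takes the remainder
    return pads

# Example: 2D transpose conv, 4x4 input, 3x3 kernel, stride 2, target 8x8:
# total = 2*(4-1) + 0 + (3-1)*1 + 1 - 8 = 1 per axis -> pads = [0, 0, 1, 1]
print(compute_convtranspose_pads([4, 4], [8, 8], [3, 3], [2, 2], [1, 1], [0, 0]))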
