Explicit padding for const shape ConvTranspose #1513

Merged
2 changes: 1 addition & 1 deletion tests/test_backend.py
@@ -4706,7 +4706,7 @@ def func(filter_val, out_backprop_val, batch_dim):
 def graph_validator(g):
     for n in g.get_nodes():
         if n.type == 'ConvTranspose':
-            return "output_shape" in n.attr
+            return "pads" in n.attr or "output_shape" in n.attr
     return False
 self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val, _INPUT2: batch_dim_val},
                     graph_validator=graph_validator)
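For readers without the tf2onnx test harness, here is a rough standalone sketch of the relaxed check written against the plain onnx API instead of tf2onnx's graph wrapper; the "model.onnx" path is hypothetical. The point of the test change is that a converted ConvTranspose may now carry either explicit pads (constant output shape) or an output_shape attribute (dynamic shape).

# Rough standalone equivalent of the relaxed validator, using the plain onnx API.
# "model.onnx" is a hypothetical file name used only for illustration.
import onnx

model = onnx.load("model.onnx")
for node in model.graph.node:
    if node.op_type == "ConvTranspose":
        attr_names = {a.name for a in node.attribute}
        # Either explicit pads (constant shape) or output_shape (dynamic shape) is acceptable.
        assert "pads" in attr_names or "output_shape" in attr_names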
34 changes: 30 additions & 4 deletions tf2onnx/onnx_opset/nn.py
@@ -435,12 +435,12 @@ def version_1(cls, ctx, node, **kwargs):
     input_dims = input_shape[1:1+spatial]
 else:
     input_dims = input_shape[2:2+spatial]
+input_dims_known = -1 not in input_dims
 output_shape_orig = node.output_shapes

-# ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
+# output_shape is explicitly specified here and then converted to explicit pads.
 output_shape = get_shape_from_const_or_concat(ctx, node.inputs[0])
 if output_shape is not None:
-    #output_shape = ctx.get_shape(node.output[0])
     if is_channels_last(node):
         new_output_shape = [output_shape[1], output_shape[2]]
         if spatial == 3:
@@ ... @@
     utils.make_sure(new_output_shape.count(-1) <= 0, "output dims need to be known")
     utils.make_sure(all(new_output_shape[i] >= input_dims[i] for i in range(spatial)),
                     "output dims cannot be smaller than input dims.")
-
-    node.set_attr("output_shape", new_output_shape)
+    if -1 in input_dims:
+        node.set_attr("output_shape", new_output_shape)
+    else:
+        if "strides" in node.attr:
+            strides = parse_dims_attr(node, node.get_attr("strides").ints, spatial)
+        else:
+            strides = [1] * spatial
+        if "dilations" in node.attr:
+            dilations = parse_dims_attr(node, node.get_attr("dilations").ints, spatial)
+        else:
+            dilations = [1] * spatial
+        if "output_padding" in node.attr:
+            output_padding = parse_dims_attr(node, node.get_attr("output_padding").ints, spatial)
+        else:
+            output_padding = [0] * spatial
+        kernel_shape = parse_dims_attr(node, node.get_attr("kernel_shape").ints, spatial)
+        total_padding = [-1] * spatial
+        pads = [1] * (spatial * 2)
+        for i in range(spatial):
+            total_padding[i] = (strides[i] * (input_dims[i] - 1) + output_padding[i]
+                                + ((kernel_shape[i] - 1) * dilations[i] + 1)
+                                - new_output_shape[i])
+            start_i = i
+            end_i = i + spatial
+            pads[start_i] = int(total_padding[i] / 2)
+            pads[end_i] = total_padding[i] - pads[start_i]
+        node.set_attr("pads", pads)
+        node.set_attr("auto_pad", "NOTSET")
 else:
     utils.make_sure(ctx.opset >= 10, "Opset 10 needed for Conv Backprop Input with non-constant shape")
     strides = parse_dims_attr(node, node.get_attr('strides').ints, spatial)
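For clarity, a minimal standalone sketch of the padding arithmetic introduced above, with the converter-specific plumbing stripped out. The helper name and example values below are made up for illustration, and floor division stands in for the int(total_padding / 2) used in the diff; the formula itself matches the one applied per spatial dimension in the new branch.

# Minimal sketch of deriving explicit ONNX "pads" from a requested output shape,
# mirroring the arithmetic in the branch above (helper name and values are illustrative).
def pads_from_output_shape(input_dims, output_shape, kernel_shape,
                           strides=None, dilations=None, output_padding=None):
    spatial = len(input_dims)
    strides = strides or [1] * spatial
    dilations = dilations or [1] * spatial
    output_padding = output_padding or [0] * spatial
    pads = [0] * (spatial * 2)
    for i in range(spatial):
        # Total padding needed so the deconvolution produces output_shape[i].
        total = (strides[i] * (input_dims[i] - 1) + output_padding[i]
                 + ((kernel_shape[i] - 1) * dilations[i] + 1)
                 - output_shape[i])
        pads[i] = total // 2                   # padding at the start of spatial dim i
        pads[i + spatial] = total - pads[i]    # remainder goes at the end
    return pads

# Example: upsampling 8x8 to 16x16 with a 3x3 kernel and stride 2 yields pads [0, 0, 1, 1].
print(pads_from_output_shape([8, 8], [16, 16], [3, 3], strides=[2, 2]))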