Skip to content

Commit

Permalink
Fixes for dtype bugs
Browse files Browse the repository at this point in the history
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft committed Apr 7, 2021
1 parent 9f50e9d commit ce67e78
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def maybe_cast_input(self, supported, type_map):
shape = self.graph.get_shape(name)
cast_node = self.graph.insert_new_node_on_input(
self, "Cast", name, to=tdtype)
self.graph.set_dtype(cast_node.output[0], [tdtype])
self.graph.set_dtype(cast_node.output[0], tdtype)
self.graph.set_shape(cast_node.output[0], shape)
did_cast = True
return did_cast
Expand Down
7 changes: 4 additions & 3 deletions tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,10 @@ def version_7(cls, ctx, node, **kwargs):
if n.type in ["TensorListReserve", "TensorListResize"]:
# there is no equivalent step in onnx and we should remove it.
output_shape = None
output_dtype = n.get_attr_value("element_dtype")
if n.type == "TensorListReserve" and n.inputs[0].is_const() and not n.inputs[0].is_scalar():
output_shape = [-1] + n.inputs[0].get_tensor_value(as_list=True)
scan_outputs.append((idx, n, output_shape))
scan_outputs.append((idx, n, output_shape, output_dtype))
continue

# tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
Expand All @@ -423,7 +424,7 @@ def version_7(cls, ctx, node, **kwargs):

scan_output_names = []
# remove tensor array that are passed in to the loop
for idx, n, output_shape in reversed(scan_outputs):
for idx, n, output_shape, output_dtype in reversed(scan_outputs):
ctx.remove_node(n.name)
# make the node output bad
ctx.replace_all_inputs(n.output[0], "@@ALLOC") # ops=ctx.get_nodes()
Expand All @@ -433,7 +434,7 @@ def version_7(cls, ctx, node, **kwargs):
scan_output_names.append(body.outputs[idx])
del body.outputs[idx]
output_shapes.append(output_shape)
output_dtypes.append(output_dtypes[idx])
output_dtypes.append(output_dtype)
output_names.append(output_names[idx])
del output_shapes[idx]
del output_dtypes[idx]
Expand Down
2 changes: 2 additions & 0 deletions tf2onnx/onnx_opset/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def randuniform_int(cls, ctx, rand_node, rand_out, min_inp, max_inp):
dtype = ctx.get_dtype(rand_out)
min_node = ctx.get_node_by_output(min_inp)
max_node = ctx.get_node_by_output(max_inp)
ctx.set_dtype(rand_node.output[0], onnx_pb.TensorProto.FLOAT)
ctx.set_dtype(rand_out, onnx_pb.TensorProto.FLOAT)
if min_node.is_const() and max_node.is_const():
rand_node.set_attr('low', float(min_node.get_tensor_value()))
rand_node.set_attr('high', float(max_node.get_tensor_value()))
Expand Down
1 change: 1 addition & 0 deletions tf2onnx/onnx_opset/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def version_11(cls, ctx, node, **kwargs):
to=dtypes[0])
ctx.set_dtype(cast_back_node.output[0], dtypes[0])
ctx.copy_shape(node.name, cast_back_node.output[0])
ctx.copy_dtype(node.input[0], node.output[0])


@tf_op("SquaredDistance", onnx_op="MeanSquaredDistance")
Expand Down
2 changes: 2 additions & 0 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ def _convert(cls, ctx, node, **kwargs):
cast_back_node = ctx.make_node("Cast", [node.output[0]], dtypes=[origin_dtype], shapes=output_shapes,
name=node.name + "_castback", attr={"to": origin_dtype})
_ = ctx.insert_node_on_output(cast_back_node, node.output[0])
ctx.set_dtype(node.output[0], onnx_pb.TensorProto.FLOAT)

if len(node.input) < 3:
kernel_shape_tf = node.get_attr("ksize").ints
Expand Down Expand Up @@ -826,6 +827,7 @@ def version_6(cls, ctx, node, **kwargs):
to=x_dtype)
ctx.set_dtype(cast_back_node.output[0], x_dtype)
ctx.copy_shape(node.name, cast_back_node.output[0])
ctx.set_dtype(node.output[0], mean_type)

consumers = [ctx.find_output_consumers(output_name) for output_name in node.output[1:]]
if not any(consumers):
Expand Down
2 changes: 2 additions & 0 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _wrap_concat_with_cast(ctx, node):
to=dtype)
ctx.set_dtype(output_cast.output[0], dtype)
ctx.copy_shape(output_name, output_cast.output[0])
ctx.set_dtype(node.output[0], onnx_pb.TensorProto.FLOAT)


@tf_op("Size")
Expand Down Expand Up @@ -1170,6 +1171,7 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
cast_out = ctx.insert_new_node_on_output("Cast", node.output[1], name=utils.make_name(node.name), to=dtypes[1])
ctx.set_dtype(cast_out.output[0], dtypes[1])
ctx.copy_shape(node.output[1], cast_out.output[0])
ctx.set_dtype(node.output[1], onnx_pb.TensorProto.INT64)

@classmethod
def version_10(cls, ctx, node, **kwargs):
Expand Down
1 change: 0 additions & 1 deletion tf2onnx/rewriter/quantization_ops_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def create_qdq_nodes(g, match_results):
y_zero_point.output[0]],
shapes=[qdq_node_output_shape],
attr=attrs,
dtypes=[qdq_node_output_dtype],
name=utils.make_name("QuantLinearNode"))

g.set_shape(quant_node.output[0], qdq_node_output_shape)
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def read_tf_node_attrs(node):
"TI", "Tparams", "Tindices", "Tlen", "Tdim", "Tin", "dynamic_size", "Tmultiples",
"Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "Tperm", "Tcond",
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
"T_threshold", "shape_type", "_lower_using_switch_merge",
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
"Toutput_types", "dense_shapes", "Tdense", "Tsegmentids", "Tshift", "Tnumsegments", "SrcT",
Expand Down

0 comments on commit ce67e78

Please sign in to comment.