Skip to content

Commit

Permalink
add unit32 unit64 type support (#1808)
Browse files Browse the repository at this point in the history
fixes #1802
Signed-off-by: hwangdeyu <dejack953@outlook.com>
Co-authored-by: fatcat-z <jiz@microsoft.com>
  • Loading branch information
hwangdeyu authored Dec 23, 2021
1 parent 6691850 commit 5cd3b5b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2353,6 +2353,12 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

x_val = np.array([1, 2, 3, 4], dtype=np.uint32).reshape((2, 2))
def func(x):
x_ = tf.cast(x, tf.uint64)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(7, "sign")
def test_sign(self):
x_vals = [np.array([1.0, 2.0, 0.0, -1.0, 0.0, -2.0], dtype=np.float32).reshape((2, 3)),
Expand Down
2 changes: 2 additions & 0 deletions tf2onnx/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
types_pb2.DT_INT8: onnx_pb.TensorProto.INT8,
types_pb2.DT_UINT8: onnx_pb.TensorProto.UINT8,
types_pb2.DT_UINT16: onnx_pb.TensorProto.UINT16,
types_pb2.DT_UINT32: onnx_pb.TensorProto.UINT32,
types_pb2.DT_UINT64: onnx_pb.TensorProto.UINT64,
types_pb2.DT_INT64: onnx_pb.TensorProto.INT64,
types_pb2.DT_STRING: onnx_pb.TensorProto.STRING,
types_pb2.DT_COMPLEX64: onnx_pb.TensorProto.COMPLEX64,
Expand Down
4 changes: 4 additions & 0 deletions tf2onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
onnx_pb.TensorProto.INT8: np.int8,
onnx_pb.TensorProto.UINT8: np.uint8,
onnx_pb.TensorProto.UINT16: np.uint16,
onnx_pb.TensorProto.UINT32: np.uint32,
onnx_pb.TensorProto.UINT64: np.uint64,
onnx_pb.TensorProto.INT64: np.int64,
onnx_pb.TensorProto.UINT64: np.uint64,
onnx_pb.TensorProto.BOOL: np.bool,
Expand All @@ -58,6 +60,8 @@
onnx_pb.TensorProto.INT8: "int8",
onnx_pb.TensorProto.UINT8: "uint8",
onnx_pb.TensorProto.UINT16: "uint16",
onnx_pb.TensorProto.UINT32: "uint32",
onnx_pb.TensorProto.UINT64: "uint64",
onnx_pb.TensorProto.INT64: "int64",
onnx_pb.TensorProto.STRING: "string",
onnx_pb.TensorProto.BOOL: "bool",
Expand Down

0 comments on commit 5cd3b5b

Please sign in to comment.