diff --git a/tests/test_backend.py b/tests/test_backend.py
index b7fca7dd6..ceaabd94a 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -3837,6 +3837,30 @@ def func(splits1, splits2, rt_dense_values):
         self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2],
                             {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
 
+    @check_tf_min_version("1.14", "ragged needs tf 1.14")
+    @check_opset_min_version(11, "CumSum")
+    def test_ragged_tensor_to_tensor(self):
+        splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
+        splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
+        dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
+        def func(splits1, splits2, rt_dense_values):
+            x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
+            y = x.to_tensor(default_value=7)
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
+
+    @check_tf_min_version("2.2", "ragged to_tensor with constrained shape")
+    @check_opset_min_version(11, "CumSum")
+    def test_ragged_tensor_to_tensor_constrain_shape(self):
+        splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
+        splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
+        dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
+        def func(splits1, splits2, rt_dense_values):
+            x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
+            y = x.to_tensor(default_value=7, shape=[20, None, 2])
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
+
     @check_tf_min_version("1.14", "ragged needs tf 1.14")
     @check_opset_min_version(11, "Range")
     def test_ragged_range_float(self):
diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index 4f7dd233a..d4fd359b6 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -2077,6 +2077,37 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
     return num_rows, num_cols, row_indices, col_indices
 
 
+def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
+    sparse_indices = None
+    dense_shape_dims = []
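+    # Walk the ragged dimensions outermost-first: each row_splits tensor is
+    # turned into per-row lengths, the lengths into (row, col) index pairs,
+    # and the row half is gathered from the previous level's index tuples so
+    # that every pass extends each index by one more dimension.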
+    for split in nested_splits:
+        if ctx.get_dtype(split) != TensorProto.INT64:
+            split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
+        max_int64 = int(utils.get_max_value(np.int64))
+        slice1 = GraphBuilder(ctx).make_slice(
+            {"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
+        slice2 = GraphBuilder(ctx).make_slice(
+            {"data": split, "ends": [-1], "starts": [0], "axes": [0]})
+        ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
+        num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
+        if not dense_shape_dims:
+            dense_shape_dims.append(num_rows)
+        dense_shape_dims.append(num_cols)
+        if sparse_indices is None:
+            row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
+        else:
+            row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
+        col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
+        sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
+                                       op_name_scope=op_name_scope).output[0]
+    dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=op_name_scope).output[0]
+    return sparse_indices, dense_shape
+
+
 @tf_op("RaggedTensorToSparse")
 class RaggedTensorToSparse:
     @classmethod
@@ -2084,36 +2115,48 @@ def version_11(cls, ctx, node, **kwargs):
         # https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
         dense_values = node.input[-1]
         nested_splits = node.input[:-1]
-        sparse_indices = None
-        dense_shape_dims = []
-        for split in nested_splits:
-            if ctx.get_dtype(split) != TensorProto.INT64:
-                split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
-            max_int64 = int(utils.get_max_value(np.int64))
-            slice1 = GraphBuilder(ctx).make_slice(
-                {"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
-            slice2 = GraphBuilder(ctx).make_slice(
-                {"data": split, "ends": [-1], "starts": [0], "axes": [0]})
-            ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
-            num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
-            if not dense_shape_dims:
-                dense_shape_dims.append(num_rows)
-            dense_shape_dims.append(num_cols)
-            if sparse_indices is None:
-                row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
-            else:
-                row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
-            col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
-            sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
-                                           op_name_scope=node.name).output[0]
-        dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=node.name).output[0]
-
+        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
         ctx.replace_all_inputs(node.output[0], sparse_indices)
         ctx.replace_all_inputs(node.output[1], dense_values)
         ctx.replace_all_inputs(node.output[2], dense_shape)
         ctx.remove_node(node.name)
 
 
+@tf_op("RaggedTensorToTensor")
+class RaggedTensorToTensor:
+    @classmethod
+    def version_11(cls, ctx, node, **kwargs):
+        shape, values, default_value, *row_partition_tensors = node.input
+        partition_types = node.get_attr_value("row_partition_types")
+        error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
+        utils.make_sure(all(t == b'ROW_SPLITS' for t in partition_types), error_msg, partition_types)
+        nested_splits = row_partition_tensors
+        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
+        # A shape of rank 0 means the natural shape should be used.
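+        # Otherwise the requested shape wins: dims left unspecified (negative,
+        # e.g. -1) fall back to the natural dense shape, and any sparse index
+        # outside the requested bounds is dropped together with its value.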
+        if ctx.get_rank(shape) != 0:
+            if ctx.get_dtype(shape) != TensorProto.INT64:
+                shape = ctx.make_node("Cast", [shape], attr={'to': TensorProto.INT64}).output[0]
+            const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
+            unspec_dims = ctx.make_node("Less", [shape, const_zero_int64]).output[0]
+            int_max_val = np.array([utils.get_max_value(np.int64)], dtype=np.int64)
+            const_int_max = ctx.make_const(utils.make_name("largest_int_val"), int_max_val).output[0]
+            out_shape = ctx.make_node("Where", [unspec_dims, dense_shape, shape]).output[0]
+            out_shape_unsq = GraphBuilder(ctx).make_unsqueeze({'data': out_shape, 'axes': [0]})
+            amt_idx_in_bounds = ctx.make_node("Sub", [out_shape_unsq, sparse_indices]).output[0]
+            amt_in_bounds_flat = ctx.make_node("ReduceMin", [amt_idx_in_bounds], attr={'axes': [1], 'keepdims': False})
+            idx_in_bounds = ctx.make_node("Greater", [amt_in_bounds_flat.output[0], const_zero_int64]).output[0]
+            sparse_indices = ctx.make_node("Compress", [sparse_indices, idx_in_bounds], attr={'axis': 0}).output[0]
+            values = ctx.make_node("Compress", [values, idx_in_bounds], attr={'axis': 0}).output[0]
+        else:
+            out_shape = dense_shape
+        expand_node = ctx.make_node("Expand", [default_value, out_shape])
+        node.type = "ScatterND"
+        ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, values])
+
+
 @tf_op("RaggedRange")
 class RaggedRange:
     @classmethod
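
Note on the conversion strategy (not part of the patch): the converter turns the
nested row_splits into explicit sparse indices, broadcasts default_value to the
output shape with Expand, and scatters the flat values on top with ScatterND.
Below is a minimal NumPy sketch of the result this computes for the first test
case above, assuming the natural dense shape (no shape constraint); variable
names are illustrative only.

    import numpy as np

    splits1 = np.array([0, 1, 1, 5])         # outer row_splits: 3 ragged rows
    splits2 = np.array([0, 3, 3, 5, 9, 10])  # inner row_splits: 5 ragged rows
    values = np.arange(10, 20).astype(np.float32)

    # Recover the (i, j, k) index of every flat value, mirroring what
    # ragged_nested_splits_to_sparse_indices builds in the ONNX graph.
    indices = []
    for i in range(len(splits1) - 1):
        for j in range(splits1[i + 1] - splits1[i]):
            row = splits1[i] + j
            for k in range(splits2[row + 1] - splits2[row]):
                indices.append((i, j, k))

    # Natural dense shape: one dim per ragged level, sized by the longest row.
    dense_shape = (len(splits1) - 1,
                   (splits1[1:] - splits1[:-1]).max(),
                   (splits2[1:] - splits2[:-1]).max())

    dense = np.full(dense_shape, 7.0, dtype=np.float32)  # Expand(default_value)
    for idx, v in zip(indices, values):                  # ScatterND
        dense[idx] = v
    # dense now matches x.to_tensor(default_value=7) in test_ragged_tensor_to_tensor.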