Skip to content

Commit

Permalink
Implement RaggedTensorToTensor conversion
Browse files Browse the repository at this point in the history
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft committed Feb 10, 2021
1 parent 8aa1127 commit 6a55d93
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 24 deletions.
24 changes: 24 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3837,6 +3837,30 @@ def func(splits1, splits2, rt_dense_values):
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2],
{_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})

@check_tf_min_version("1.14", "ragged needs tf 1.14")
@check_opset_min_version(11, "CumSum")
def test_ragged_tensor_to_tensor(self):
splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
def func(splits1, splits2, rt_dense_values):
x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
y = x.to_tensor(default_value=7)
return tf.identity(y, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})

@check_tf_min_version("2.2", "ragged to_tensor with constrained shape")
@check_opset_min_version(11, "CumSum")
def test_ragged_tensor_to_tensor_constrain_shape(self):
splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
def func(splits1, splits2, rt_dense_values):
x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
y = x.to_tensor(default_value=7, shape=[20, None, 2])
return tf.identity(y, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})

@check_tf_min_version("1.14", "ragged needs tf 1.14")
@check_opset_min_version(11, "Range")
def test_ragged_range_float(self):
Expand Down
84 changes: 60 additions & 24 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2077,43 +2077,79 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
return num_rows, num_cols, row_indices, col_indices


def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
sparse_indices = None
dense_shape_dims = []
for split in nested_splits:
if ctx.get_dtype(split) != TensorProto.INT64:
split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
max_int64 = int(utils.get_max_value(np.int64))
slice1 = GraphBuilder(ctx).make_slice(
{"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
slice2 = GraphBuilder(ctx).make_slice(
{"data": split, "ends": [-1], "starts": [0], "axes": [0]})
ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
if not dense_shape_dims:
dense_shape_dims.append(num_rows)
dense_shape_dims.append(num_cols)
if sparse_indices is None:
row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
else:
row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
op_name_scope=op_name_scope).output[0]
dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=op_name_scope).output[0]
return sparse_indices, dense_shape


@tf_op("RaggedTensorToSparse")
class RaggedTensorToSparse:
@classmethod
def version_11(cls, ctx, node, **kwargs):
# https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
dense_values = node.input[-1]
nested_splits = node.input[:-1]
sparse_indices = None
dense_shape_dims = []
for split in nested_splits:
if ctx.get_dtype(split) != TensorProto.INT64:
split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
max_int64 = int(utils.get_max_value(np.int64))
slice1 = GraphBuilder(ctx).make_slice(
{"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
slice2 = GraphBuilder(ctx).make_slice(
{"data": split, "ends": [-1], "starts": [0], "axes": [0]})
ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
if not dense_shape_dims:
dense_shape_dims.append(num_rows)
dense_shape_dims.append(num_cols)
if sparse_indices is None:
row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
else:
row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
op_name_scope=node.name).output[0]
dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=node.name).output[0]

sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
ctx.replace_all_inputs(node.output[0], sparse_indices)
ctx.replace_all_inputs(node.output[1], dense_values)
ctx.replace_all_inputs(node.output[2], dense_shape)
ctx.remove_node(node.name)


@tf_op("RaggedTensorToTensor")
class RaggedTensorToTensor:
@classmethod
def version_11(cls, ctx, node, **kwargs):
shape, values, default_value, *row_partition_tensors = node.input
partition_types = node.get_attr_value("row_partition_types")
error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
utils.make_sure(all(t == b'ROW_SPLITS' for t in partition_types), error_msg, partition_types)
nested_splits = row_partition_tensors
sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
# A shape of rank 0 means the natural shape should be used.
if ctx.get_rank(shape) != 0:
if ctx.get_dtype(shape) != TensorProto.INT64:
shape = ctx.make_node("Cast", [shape], attr={'to': TensorProto.INT64}).output[0]
const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
unspec_dims = ctx.make_node("Less", [shape, const_zero_int64]).output[0]
int_max_val = np.array([utils.get_max_value(np.int64)], dtype=np.int64)
const_int_max = ctx.make_const(utils.make_name("largest_int_val"), int_max_val).output[0]
out_shape = ctx.make_node("Where", [unspec_dims, dense_shape, shape]).output[0]
out_shape_unsq = GraphBuilder(ctx).make_unsqueeze({'data': out_shape, 'axes': [0]})
amt_idx_in_bounds = ctx.make_node("Sub", [out_shape_unsq, sparse_indices]).output[0]
amt_in_bounds_flat = ctx.make_node("ReduceMin", [amt_idx_in_bounds], attr={'axes': [1], 'keepdims': False})
idx_in_bounds = ctx.make_node("Greater", [amt_in_bounds_flat.output[0], const_zero_int64]).output[0]
sparse_indices = ctx.make_node("Compress", [sparse_indices, idx_in_bounds], attr={'axis': 0}).output[0]
values = ctx.make_node("Compress", [values, idx_in_bounds], attr={'axis': 0}).output[0]
else:
out_shape = dense_shape
expand_node = ctx.make_node("Expand", [default_value, out_shape])
node.type = "ScatterND"
ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, values])


@tf_op("RaggedRange")
class RaggedRange:
@classmethod
Expand Down

0 comments on commit 6a55d93

Please sign in to comment.