Skip to content

Commit

Permalink
Implement VALUE_ROWIDS format for ragged to tensor
Browse files Browse the repository at this point in the history
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft committed Aug 11, 2021
1 parent 2def374 commit f18c382
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
13 changes: 13 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4666,6 +4666,19 @@ def func(splits1, splits2, rt_dense_values):
return tf.identity(y, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})

@check_tf_min_version("1.14", "ragged needs tf 1.14")
@check_opset_min_version(11, "CumSum")
@skip_tflite("unknown rank")
def test_ragged_tensor_to_tensor_row_ids(self):
ids_val1 = np.array([0, 0, 0, 2, 2], dtype=np.int32)
ids_val2 = np.array([0, 0, 2, 2, 2, 3, 3, 4], dtype=np.int32)
dense_vals_val = np.array([10, 20, 30, 40, 50, 60, 70, 80], dtype=np.float32)
def func(ids1, ids2, rt_dense_values):
x = tf.RaggedTensor.from_nested_value_rowids(rt_dense_values, [ids1, ids2], [4, 5])
y = x.to_tensor(default_value=7)
return tf.identity(y, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: ids_val1, _INPUT1: ids_val2, _INPUT2: dense_vals_val})

@check_tf_min_version("2.2", "ragged to_tensor with constrained shape")
@check_opset_min_version(11, "CumSum")
def test_ragged_tensor_to_tensor_constrain_shape(self):
Expand Down
62 changes: 58 additions & 4 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2385,6 +2385,25 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
return num_rows, num_cols, row_indices, col_indices


def ragged_row_ids_to_sparse_indices(ctx, row_ids):
_, indices, _, counts = ctx.make_node("Unique", [row_ids], attr={'axis': 0}, output_count=4).output
num_cols = ctx.make_node("ReduceMax", [counts], attr={'axes': [0], 'keepdims': True}).output[0]
const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
const_zero_unsq = ctx.make_const(utils.make_name("const_zero"), np.array([0], np.int64)).output[0]
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
const_neg_one_unsq = ctx.make_const(utils.make_name("const_neg_one"), np.array([-1], np.int64)).output[0]
one_minus_cnt = ctx.make_node("Sub", [const_one, counts]).output[0]
cnts_prefixed = ctx.make_node("Concat", [const_zero_unsq, one_minus_cnt], attr={'axis': 0}).output[0]
cnts_shifted = GraphBuilder(ctx).make_slice(
{'data': cnts_prefixed, 'starts': const_zero_unsq, 'ends': const_neg_one_unsq, 'axes': [0]})
ids_shape = ctx.make_node("Shape", [row_ids]).output[0]
one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[1])
ones_of_shape = ctx.make_node("ConstantOfShape", [ids_shape], attr={'value': one_tensor}).output[0]
deltas = ctx.make_node("ScatterElements", [ones_of_shape, indices, cnts_shifted], attr={'axis': 0}).output[0]
col_indices = ctx.make_node("CumSum", [deltas, const_zero]).output[0]
return num_cols, col_indices


def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
sparse_indices = None
dense_shape_dims = []
Expand Down Expand Up @@ -2412,6 +2431,28 @@ def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
return sparse_indices, dense_shape


def ragged_nested_row_ids_to_sparse_indices(ctx, num_rows, nested_row_ids, op_name_scope):
sparse_indices = None
if ctx.get_dtype(num_rows) != TensorProto.INT64:
num_rows = ctx.make_node("Cast", [num_rows], attr={'to': TensorProto.INT64}).output[0]
num_rows = GraphBuilder(ctx).make_unsqueeze({"data": num_rows, "axes": [0]})
dense_shape_dims = [num_rows]
for row_ids in nested_row_ids:
if ctx.get_dtype(row_ids) != TensorProto.INT64:
row_ids = ctx.make_node("Cast", [row_ids], attr={'to': TensorProto.INT64}).output[0]
num_cols, col_indices = ragged_row_ids_to_sparse_indices(ctx, row_ids)
dense_shape_dims.append(num_cols)
if sparse_indices is None:
row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_ids, "axes": [1]})
else:
row_indices = ctx.make_node("Gather", [sparse_indices, row_ids]).output[0]
col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
op_name_scope=op_name_scope).output[0]
dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=op_name_scope).output[0]
return sparse_indices, dense_shape


@tf_op("RaggedTensorToSparse")
class RaggedTensorToSparse:
@classmethod
Expand All @@ -2432,10 +2473,23 @@ class RaggedTensorToTensor:
def version_11(cls, ctx, node, **kwargs):
shape, values, default_value, *row_partition_tensors = node.input
partition_types = node.get_attr_value("row_partition_types")
error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
utils.make_sure(all(t == b'ROW_SPLITS' for t in partition_types), error_msg, partition_types)
nested_splits = row_partition_tensors
sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
layout_type = None
if len(partition_types) >= 2 and partition_types[0] == b'FIRST_DIM_SIZE' and \
all(t == b'VALUE_ROWIDS' for t in partition_types[1:]):
layout_type = 'VALUE_ROWIDS'
elif all(t == b'ROW_SPLITS' for t in partition_types):
layout_type = 'ROW_SPLITS'
error_msg = "Only ROW_SPLITS partition and VALUE_ROWIDS types supported for RaggedTensorToTensor. types: %r"

if layout_type == 'ROW_SPLITS':
nested_splits = row_partition_tensors
sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
else:
utils.make_sure(layout_type == 'VALUE_ROWIDS', error_msg, partition_types)
first_dim = row_partition_tensors[0]
row_ids = row_partition_tensors[1:]
sparse_indices, dense_shape = ragged_nested_row_ids_to_sparse_indices(ctx, first_dim, row_ids, node.name)

# A shape of rank 0 means the natural shape should be used.
if ctx.get_rank(shape) != 0:
if ctx.get_dtype(shape) != TensorProto.INT64:
Expand Down

0 comments on commit f18c382

Please sign in to comment.