Skip to content

Commit

Permalink
Implement RaggedTensorToTensor for tensors with dense (uniform) dims
Browse files Browse the repository at this point in the history
  • Loading branch information
TomWildenhain-Microsoft committed Aug 15, 2021
1 parent 2be4cf3 commit 1933199
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4700,10 +4700,10 @@ def func(splits1, splits2, rt_dense_values):
def test_ragged_tensor_to_tensor_row_ids(self):
ids_val1 = np.array([0, 0, 0, 2, 2], dtype=np.int32)
ids_val2 = np.array([0, 0, 2, 2, 2, 3, 3, 4], dtype=np.int32)
dense_vals_val = np.array([10, 20, 30, 40, 50, 60, 70, 80], dtype=np.float32)
dense_vals_val = make_xval([8, 2, 3])
def func(ids1, ids2, rt_dense_values):
x = tf.RaggedTensor.from_nested_value_rowids(rt_dense_values, [ids1, ids2], [4, 5])
y = x.to_tensor(default_value=7)
y = x.to_tensor(default_value=7, shape=[None, None, None, 2, None])
return tf.identity(y, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: ids_val1, _INPUT1: ids_val2, _INPUT2: dense_vals_val})

Expand Down
21 changes: 21 additions & 0 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2460,6 +2460,8 @@ def version_11(cls, ctx, node, **kwargs):
# https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
dense_values = node.input[-1]
nested_splits = node.input[:-1]
err_msg2 = "RaggedTensorToSparse conversion only supports tensors with no dense dimensions"
utils.make_sure(ctx.get_rank(dense_values) in [None, 1], err_msg2)
sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
ctx.replace_all_inputs(node.output[0], sparse_indices)
ctx.replace_all_inputs(node.output[1], dense_values)
Expand All @@ -2472,6 +2474,7 @@ class RaggedTensorToTensor:
@classmethod
def version_11(cls, ctx, node, **kwargs):
shape, values, default_value, *row_partition_tensors = node.input
has_uniform_dims = ctx.get_rank(values) != 1
partition_types = node.get_attr_value("row_partition_types")
layout_type = None
if len(partition_types) >= 2 and partition_types[0] == b'FIRST_DIM_SIZE' and \
Expand All @@ -2483,17 +2486,25 @@ def version_11(cls, ctx, node, **kwargs):

if layout_type == 'ROW_SPLITS':
nested_splits = row_partition_tensors
n_dims = len(nested_splits) + 1
sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
else:
utils.make_sure(layout_type == 'VALUE_ROWIDS', error_msg, partition_types)
first_dim = row_partition_tensors[0]
row_ids = row_partition_tensors[1:]
n_dims = len(row_ids) + 1
sparse_indices, dense_shape = ragged_nested_row_ids_to_sparse_indices(ctx, first_dim, row_ids, node.name)

# A shape of rank 0 means the natural shape should be used.
if ctx.get_rank(shape) != 0:
if ctx.get_dtype(shape) != TensorProto.INT64:
shape = ctx.make_node("Cast", [shape], attr={'to': TensorProto.INT64}).output[0]
if has_uniform_dims:
const_zero_unsq = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64)).output[0]
const_n_unsq = ctx.make_const(utils.make_name("const_num_dims"),
np.array([n_dims], dtype=np.int64)).output[0]
shape = GraphBuilder(ctx).make_slice(
{'data': shape, 'starts': const_zero_unsq, 'ends': const_n_unsq, 'axes': [0]})
const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
unspec_dims = ctx.make_node("Less", [shape, const_zero_int64]).output[0]
out_shape = ctx.make_node("Where", [unspec_dims, dense_shape, shape]).output[0]
Expand All @@ -2505,6 +2516,16 @@ def version_11(cls, ctx, node, **kwargs):
values = ctx.make_node("Compress", [values, idx_in_bounds], attr={'axis': 0}).output[0]
else:
out_shape = dense_shape

if has_uniform_dims:
values_shape = ctx.make_node("Shape", [values]).output[0]
const_one_unsq = ctx.make_const(utils.make_name("const_one"), np.array([1], dtype=np.int64)).output[0]
max_int64 = np.array([utils.get_max_value(np.int64)], dtype=np.int64)
const_max_val_unsq = ctx.make_const(utils.make_name("max_int"), max_int64).output[0]
uniform_dims = GraphBuilder(ctx).make_slice(
{'data': values_shape, 'starts': const_one_unsq, 'ends': const_max_val_unsq, 'axes':[0]})
out_shape = ctx.make_node("Concat", [out_shape, uniform_dims], attr={'axis': 0}).output[0]

expand_node = ctx.make_node("Expand", [default_value, out_shape])
node.type = "ScatterND"
ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, values])
Expand Down

0 comments on commit 1933199

Please sign in to comment.