Skip to content

Commit

Permalink
Add support for GatherV2 batch_dims attr (#1329)
Browse files Browse the repository at this point in the history
* Add support for GatherV2 batch_dims attr

Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>

* Fix tests

Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft authored Feb 11, 2021
1 parent fea121d commit c23ef70
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,26 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_tf_min_version("1.14")
@check_opset_min_version(12, "GatherND with batch_dims")
def test_gather_batch_dims_no_trans(self):
    """Gather with batch_dims equal to axis (batch_dims=2, axis=2)."""
    data_shape = (2, 2, 3, 5, 4)
    data = np.arange(np.prod(data_shape), dtype=np.float32).reshape(data_shape)
    # One row of four indices per (batch0, batch1) pair; every value is a
    # valid position along the gathered dimension of size 3.
    indices = np.array(
        [
            [[1, 0, 2, 0], [1, 1, 1, 0]],
            [[0, 0, 0, 0], [2, 1, 1, 0]],
        ],
        dtype=np.int32,
    )

    def func(x, idx):
        gathered = tf.gather(x, idx, batch_dims=2, axis=2)
        return tf.identity(gathered, name=_TFOUTPUT)

    self._run_test_case(func, [_OUTPUT], {_INPUT: data, _INPUT1: indices})

@check_tf_min_version("1.14")
@check_opset_min_version(12, "GatherND with batch_dims")
def test_gather_batch_dims(self):
    """Gather with axis past the batch dimensions (batch_dims=2, axis=3)."""
    data_shape = (2, 2, 3, 5, 4)
    data = np.arange(np.prod(data_shape), dtype=np.float32).reshape(data_shape)
    # One row of four indices per (batch0, batch1) pair; every value is a
    # valid position along the gathered dimension of size 5.
    indices = np.array(
        [
            [[1, 0, 2, 0], [1, 1, 1, 0]],
            [[0, 0, 0, 0], [2, 1, 1, 0]],
        ],
        dtype=np.int32,
    )

    def func(x, idx):
        gathered = tf.gather(x, idx, batch_dims=2, axis=3)
        return tf.identity(gathered, name=_TFOUTPUT)

    self._run_test_case(func, [_OUTPUT], {_INPUT: data, _INPUT1: indices})

@check_opset_min_version(10, "Slice")
def test_roll_axis_scalar(self):
x_val = np.arange(4 * 3 * 5 * 2, dtype=np.float32).reshape((4, 3, 5, 2))
Expand Down
39 changes: 39 additions & 0 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,10 @@ class GatherV2:
@classmethod
def version_1(cls, ctx, node, **kwargs):
    """Rewrite GatherV2 as Gather, moving the constant axis input to an attribute.

    Fails through utils.make_sure when the node carries a non-zero
    batch_dims attribute or when its axis input is not a constant.
    """
    uses_batch_dims = node.get_attr_value("batch_dims", 0) != 0
    utils.make_sure(not uses_batch_dims, "Opset 12 required for batch_dims attribute of GatherV2")
    node.type = "Gather"
    # GatherV2 receives its axis as the third input rather than as an attribute.
    axis_node = node.inputs[2]
    utils.make_sure(axis_node.is_const(), "Axis of GatherV2 node must be constant")
    axis_value = axis_node.get_tensor_value()
    ctx.remove_input(node, node.input[2], 2)
    node.set_attr("axis", axis_value)
Expand All @@ -433,6 +436,42 @@ def version_11(cls, ctx, node, **kwargs):
# no change
cls.version_1(ctx, node, **kwargs)

@classmethod
def version_12(cls, ctx, node, **kwargs):
    """Convert GatherV2, emulating a non-zero batch_dims with GatherND.

    When batch_dims is 0 this defers to version_1. Otherwise the node is
    retyped in place to GatherND: the axis input is removed, the indices
    are cast to int64 if needed and given a trailing unit axis, and, when
    axis differs from batch_dims, the data is transposed before the node
    and the result transposed after it.

    Fails through utils.make_sure when the axis input is not constant, or
    when a rank that is needed (negative axis, or axis != batch_dims) is
    unknown.
    """
    batch_dims = node.get_attr_value("batch_dims", 0)
    if batch_dims == 0:
        cls.version_1(ctx, node, **kwargs)
        return
    # If batch_dims is not zero, use GatherND to simulate Gather with batch dims.
    # NOTE(review): the batch_dims attribute is left on the node untouched;
    # presumably it is then read as GatherND's batch_dims - confirm.
    data_inp, indices_inp, axis_inp = node.input
    utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
    axis = node.inputs[2].get_tensor_value()
    ctx.remove_input(node, axis_inp, 2)
    if ctx.get_dtype(indices_inp) != TensorProto.INT64:
        indices_inp = ctx.make_node("Cast", [indices_inp], attr={'to': TensorProto.INT64}).output[0]
    rank_err_msg = "Cannot convert GatherV2 with batch dims since inputs have unknown ranks."
    if axis < 0:
        # tf.gather accepts a negative axis (counted from the end of the data).
        # The slicing below needs a non-negative one, so normalize it first.
        neg_axis_rank = ctx.get_rank(data_inp)
        utils.make_sure(neg_axis_rank is not None, rank_err_msg)
        axis += neg_axis_rank
    unperm = None
    # GatherND doesn't take an axis so we have to transpose stuff around
    if axis != batch_dims:
        data_rank = ctx.get_rank(data_inp)
        indices_rank = ctx.get_rank(indices_inp)
        # Validate before any arithmetic so an unknown rank produces this
        # message instead of a TypeError on None.
        utils.make_sure(data_rank is not None and indices_rank is not None, rank_err_msg)
        result_rank = data_rank + indices_rank - 1 - batch_dims
        shift_amt = axis - batch_dims
        # Move the gathered axis so it sits directly after the batch dims.
        perm = list(range(data_rank))
        perm = perm[:batch_dims] + perm[axis:axis+1] + perm[batch_dims:axis] + perm[axis+1:]
        data_inp = ctx.make_node("Transpose", [data_inp], attr={'perm': perm}).output[0]
        ctx.replace_input(node, node.input[0], data_inp, 0)
        # Inverse move for the output: the shift_amt data dims that were
        # skipped over go back in front of the non-batch index dims.
        unperm = list(range(result_rank))
        j = indices_rank + shift_amt
        unperm = unperm[:batch_dims] + unperm[indices_rank:j] + unperm[batch_dims:indices_rank] + unperm[j:]
    node.type = "GatherND"
    # Each index becomes a length-1 coordinate by appending a trailing unit axis.
    unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': indices_inp, 'axes': [-1]})
    ctx.replace_input(node, node.input[1], unsqueeze_node, 1)
    if unperm is not None:
        # Refresh the node's recorded shape/dtype before adding the output Transpose.
        ctx.update_node_shape_dtype(node, override=True)
        ctx.insert_new_node_on_output("Transpose", node.output[0], perm=unperm)


def _make_gathernd_inner_loop(ctx, params, index, dtype):
"""create the inner loop for GatherNd."""
Expand Down

0 comments on commit c23ef70

Please sign in to comment.