
[Relay, ONNX] Support gather_nd batch_dims attribute for TF/ONNX #8084

Merged · 17 commits · May 21, 2021

Conversation

masahi (Member) commented May 20, 2021

Similar to #7951, which added batch_dims to the take (gather) op, this PR adds a batch_dims option to gather_nd, which is useful for TF/ONNX support.

https://www.tensorflow.org/api_docs/python/tf/gather_nd
https://github.com/onnx/onnx/blob/master/docs/Operators.md#GatherND (Opset 12 or higher)

Please review: @Laurawly @mbrookhart @comaniac @yongwww

mbrookhart (Contributor) left a comment

LGTM

comaniac (Contributor) left a comment

LGTM

tkonolige (Contributor) commented

Our implementation already has an implicit batch dimension. The current documentation reads: "Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with shape (M, Y_0, ..., Y_{K-1}), the output will have shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, the output shape will simply be (Y_0, ..., Y_{K-1})." Here X_M, ..., X_{N-1} is the implicit batch dimension. How does the explicit batch_dims differ from this? And should we consider unifying the two?

masahi (Member, Author) commented May 20, 2021

X_M, ..., X_{N-1} is the implicit batch dimension

I don't get what you meant by "implicit batch dimension". X_M, ..., X_{N-1} are axes of the input that are not indexed and thus simply copied. batch_dims tells from which axis the indexing starts.

Our current gather_nd is identical to the MXNet one
(https://mxnet.apache.org/versions/1.6/api/r/docs/api/mx.symbol.gather_nd.html), which is the same as TF gather_nd and ONNX GatherND except that

  • the M index tuples are along the first axis rather than the last, and
  • batch_dims is always 0.

(There is an open request to add batch_dims support to the MXNet op: apache/mxnet#9998.)

So right now the output is

output[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] = data[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}]
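The current (batch_dims == 0) semantics above can be sketched as a NumPy reference function. This is a minimal illustration, not TVM's actual implementation; gather_nd_ref is a hypothetical name:

```python
import numpy as np

def gather_nd_ref(data, indices):
    """Sketch of TVM-style gather_nd with batch_dims == 0.

    indices has shape (M, Y_0, ..., Y_{K-1}); the M-tuples being looked
    up lie along the FIRST axis, unlike TF/ONNX where they lie along
    the last axis of indices.
    """
    M = indices.shape[0]
    y_shape = indices.shape[1:]
    # Output shape: (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})
    out = np.empty(y_shape + data.shape[M:], dtype=data.dtype)
    for y in np.ndindex(*y_shape):
        # Collect the M-tuple index for this output position.
        idx = tuple(indices[(m,) + y] for m in range(M))
        out[y] = data[idx]
    return out

data = np.arange(6).reshape(2, 3)     # X = (2, 3)
indices = np.array([[0, 1], [2, 0]])  # M = 2, Y = (2,)
print(gather_nd_ref(data, indices))   # -> [2 3]  (data[0, 2] and data[1, 0])
```

The uncopied trailing axes X_M, ..., X_{N-1} pass through untouched, which is the "implicit batch dimension" referred to above.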

With batch_dims B, it becomes (I hope this is correct, but I haven't checked it deeply):

output[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] = data[y_0, ..., y_{B-1}, indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_{M+B}, ..., x_{N-1}]

I'm going to update the doc to the following, if this makes sense, @tkonolige:

Optionally, batch_dims, the number of batch dimensions, can be given, whose
default value is 0.

Let B denote batch_dims, and let data and indices have shapes (X_0, X_1, ..., X_{N-1})
and (M, Y_0, ..., Y_{K-1}) respectively. When B > 0, indexing starts from the B-th axis,
and it must be the case that (X_0, ..., X_{B-1}) == (Y_0, ..., Y_{B-1}).

The output will have shape
(Y_0, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N. If M + B == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
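The proposed batch_dims semantics can likewise be sketched in NumPy. Again this is only an illustration of the documented behavior under the stated constraint; gather_nd_batch_ref is a hypothetical name:

```python
import numpy as np

def gather_nd_batch_ref(data, indices, batch_dims=0):
    """Sketch of gather_nd with batch_dims = B.

    The first B axes of data must match the first B axes of indices'
    trailing shape; they are iterated jointly rather than indexed, and
    the gathered M-tuples index data starting from axis B.
    """
    B, M = batch_dims, indices.shape[0]
    y_shape = indices.shape[1:]
    assert data.shape[:B] == y_shape[:B], "batch axes must agree"
    # Output shape: (Y_0, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1})
    out = np.empty(y_shape + data.shape[M + B:], dtype=data.dtype)
    for y in np.ndindex(*y_shape):
        # Batch coordinates come first, then the gathered M-tuple.
        idx = y[:B] + tuple(indices[(m,) + y] for m in range(M))
        out[y] = data[idx]
    return out

data = np.arange(6).reshape(2, 3)  # X = (2, 3); 2 batches of 3
indices = np.array([[2, 0]])       # M = 1, Y = (2,): col 2 of row 0, col 0 of row 1
print(gather_nd_batch_ref(data, indices, batch_dims=1))  # -> [2 3]
```

With batch_dims=0 this reduces to the plain gather_nd above, which is consistent with 0 being the default.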

tkonolige (Contributor) commented

Shouldn't the batch size appear in the output shape? I think it should be (X_0, ..., X_{B-1}, Y_0, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}).

masahi (Member, Author) commented May 20, 2021

As the ONNX doc (https://github.com/onnx/onnx/blob/master/docs/Operators.md#GatherND) says, the leading B dimensions of the data and indices tensors represent the batches, so there is the constraint X_0, ..., X_{B-1} == Y_0, ..., Y_{B-1}.

When I wrote the output shape as (Y_0, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), more precisely it means (Y_0, ..., Y_{B-1}, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), which is equivalent to (X_0, ..., X_{B-1}, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}) because of that constraint.

tkonolige (Contributor) commented

Ah, that's what I was missing. I'd use the expanded definition with X_0, ... because I think it is clearer for users.

masahi (Member, Author) commented May 20, 2021

I'd use the expanded definition with X_0... because I think it is clearer for users

OK. Since the definition with X_0, ... only applies when B > 0, I added output-shape descriptions for the two cases (B > 0 and B == 0).

masahi merged commit 0d38a92 into apache:main on May 21, 2021
masahi (Member, Author) commented May 21, 2021

Thanks @mbrookhart @comaniac @tkonolige!

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 17, 2021
…che#8084)

* Add GatherND batch_dim support

* adding tests

* test working

* improved reference code

* refactor ref func

* batch dim 2 tests from tf all passed

* batch_dim -> batch_dims

* add example

* minor change

* add onnx test

* fix onnx version

* fix lint

* remove move on batch_dims

* fix pylint

* fix compiler warning

* add shape constraint for batch_dim and update doc

* make the output shape doc clearer
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jun 17, 2021
4 participants