Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Support dynamic indices size in gather_nd and scatter_nd #8105

Merged
merged 9 commits into from
May 30, 2021

Conversation

masahi
Copy link
Member

@masahi masahi commented May 21, 2021

Added shape func for gather_nd op to support dynamic indices size. scatter_nd also works with dynamic indices after dropping some assertions check on shapes.

please review @mbrookhart @jwfromm @comaniac @kevinthesun

One difficulty with gather_nd shape func was that the rank of the output tensor depends on the runtime shape of indices tensor.
At the line

out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
,
if instead I have something like

mdim = indices_shape[0]  # indices_shape is a runtime shape represented as te::Tensor 
out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")

different errors are raised depending on the context. For example, an error occurs in the shape func for squeeze that follows gather_dim, at

for i in range(inputs[0].shape[0].value):
. That line is a python range for loop that expects the rank of a shape tensor, inputs[0].shape[0].value, to be a compile time fixed integer. But if the rank of the gather_dim output tensor is calculated like above, the rank would not be a compile time integer (because it depends on the runtime value of shape).

My workaround for this problem is based on the fact that the first axis of indices tensor, whose runtime size is required to compute the output rank, is actually a compile time constant as asserted at

const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
ICHECK(mdim) << "GatherND needs a static shape for the first axis of indices, got "
. So I simply record this fixed value in the attribute dict and look it up inside shape func. To do this, I had to add a new attribute to gather_dim, which is a bit ad hoc but necessary.

@masahi masahi changed the title [Relay] Support dynamic inputs in gather_nd and scatter_nd [Relay] Support dynamic indices size in gather_nd and scatter_nd May 21, 2021
@masahi
Copy link
Member Author

masahi commented May 26, 2021

pinging for reviews.

python/tvm/relay/op/transform.py Show resolved Hide resolved
include/tvm/relay/attrs/transform.h Outdated Show resolved Hide resolved
@masahi masahi force-pushed the gather_nd_shape_func branch from 55a4430 to 53d9700 Compare May 28, 2021 04:07
@masahi masahi force-pushed the gather_nd_shape_func branch from b745d31 to 06ac205 Compare May 29, 2021 00:58
@masahi masahi merged commit 27e44ee into apache:main May 30, 2021
mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Jun 3, 2021
…che#8105)

* add gather_nd shape func

* refactor gather_nd ref funcs

* add dynamic gather_nd test

* gather_dim -> num_indices_per_tuple

* support dynamic scatter nd

* minor fix

* fix pylint

* rename to index_rank and make it Optional

* pylint, do not use -1 for default value
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 17, 2021
…che#8105)

* add gather_nd shape func

* refactor gather_nd ref funcs

* add dynamic gather_nd test

* gather_dim -> num_indices_per_tuple

* support dynamic scatter nd

* minor fix

* fix pylint

* rename to index_rank and make it Optional

* pylint, do not use -1 for default value
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jun 17, 2021
…che#8105)

* add gather_nd shape func

* refactor gather_nd ref funcs

* add dynamic gather_nd test

* gather_dim -> num_indices_per_tuple

* support dynamic scatter nd

* minor fix

* fix pylint

* rename to index_rank and make it Optional

* pylint, do not use -1 for default value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants