feat: dynamic shape support for aten.select.int #2990
Conversation
Mark torch.ops.aten.select.int with supports_dynamic_shapes=True.
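In Torch-TensorRT this flag is set on the converter registration decorator. The toy registry below sketches what such a capability flag amounts to; the names CONVERTERS, register, and select_converter are illustrative, not the real Torch-TensorRT API.

```python
# Toy converter registry with a dynamic-shape capability flag.
CONVERTERS = {}

def register(op_name, supports_dynamic_shapes=False):
    # Decorator factory: records the converter and its capabilities.
    def wrap(fn):
        CONVERTERS[op_name] = {
            "fn": fn,
            "supports_dynamic_shapes": supports_dynamic_shapes,
        }
        return fn
    return wrap

@register("aten.select.int", supports_dynamic_shapes=True)
def select_converter(ctx, target, args, kwargs, name):
    ...  # the real converter builds TensorRT layers here
```

At partitioning time, a registry like this lets the frontend ask whether an op can stay in a TensorRT subgraph when its inputs have dynamic dims.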
output_shape = get_shape_with_dynamic_shape(
    ctx, target, source_ir, name, output_shape, input
)

index_value = np.array(index, dtype=np.int32)
Can index itself be dynamic (ITensor)?
Based on the PyTorch docs and the schema, the index is always an integer and cannot be a list or tuple of integers. The indices_tensor created with index and index_value will always be a scalar. Therefore, it cannot be dynamic. I have marked it accordingly.
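A quick numpy sketch of why the indices tensor is always 0-d; numpy's take stands in here for the gather the converter builds:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
dim, index = 1, 2

# The schema gives a single int index, so the constant built from it
# is a 0-d (scalar) tensor and can never be dynamic.
index_value = np.array(index, dtype=np.int32)
assert index_value.shape == ()

# Gathering with a scalar index drops `dim` from the output shape,
# matching aten.select.int semantics.
out = np.take(x, index_value, axis=dim)
assert out.shape == (2, 4)
assert (out == x[:, index, :]).all()
```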
output_shape = get_shape_with_dynamic_shape(
    ctx, target, source_ir, name, output_shape, input
)
If the above assert()s are removed and a fully dynamic shape is used (e.g. (-1, -1, -1)), the test works. I'm wondering whether select on a dynamic dim can be supported.
I have run the test you mentioned and confirmed that using (-1, -1, -1) also passes the test cases. However, I didn't remove the assert()s, because the test cases fail when the index is out of range, i.e. when the index is larger than the corresponding dimension in the dynamic input shape.
("success case", (1, 1, 1), (2, 2, 2), (3, 3, 3), torch.float, 0, 1),
("fail case", (1, 1, 1), (2, 2, 2), (3, 3, 3), torch.float, 0, 3),
It seems that normalizing the index the way the dimension is normalized (dim = get_positive_dim(cast(int, dim), ranks)) when the index is greater than the size would solve the issue. Do you have an example that handles it this way?
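For reference, get_positive_dim essentially wraps a negative dim modulo the rank; a hypothetical analogous helper for the index (not present in the codebase, sketched here only to make the question concrete) might look like:

```python
def get_positive_dim(dim: int, rank: int) -> int:
    # Sketch of the converter-utils helper: wrap negative dims.
    return dim % rank if dim < 0 else dim

def normalize_index(index: int, dim_size: int) -> int:
    # Hypothetical analogue for index: wrap negatives, reject out-of-range.
    if index < 0:
        index += dim_size
    if not 0 <= index < dim_size:
        raise IndexError(f"index {index} out of range for size {dim_size}")
    return index
```

Note the asymmetry: a negative index can be wrapped statically, but an out-of-range positive index has no valid normalization, which is why it can only be rejected.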
I think we cannot check for an invalid index on a dynamic dim; the error will happen at runtime. Maybe we can check the index only for static dims:

if DYNAMIC_DIM != input.shape[dim] and index >= input.shape[dim]:
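A sketch of that idea, assuming the build-time convention that dynamic dimensions are encoded as -1 (DYNAMIC_DIM) in the shape:

```python
DYNAMIC_DIM = -1  # assumed sentinel for a dynamic dimension at build time

def check_index(shape, dim, index):
    # Only static dims can be validated at build time; dynamic dims
    # defer the bounds check to runtime.
    if DYNAMIC_DIM != shape[dim] and index >= shape[dim]:
        raise RuntimeError(
            f"index {index} out of bounds for dim {dim} (size {shape[dim]})"
        )

check_index((2, 3, 4), 0, 1)     # static and in range: passes
check_index((-1, 3, 4), 0, 100)  # dynamic dim: deferred to runtime
```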
I also tested by removing the assert() and raise RuntimeError() statements and running with (-1, -1, -1). The test case listed as the 'success case' above passes, but the one listed as the 'fail case' fails. The reason is that in the fail case the size of the 0-th dimension is 3, but the index selects a position beyond it (index = 3, i.e. the 4th position). The current converter does not handle the case where the index is larger than the size of the dimension specified by dim, as in the 'fail case'. (Note: it might be possible to handle this using the slice layer and the lt (less than) and div functions, but currently it is handled with an assert statement.) Therefore, I left the assert() and raise RuntimeError() statements in place.
Oh! I misunderstood. If the index is larger than the size of the dimension specified by dim, it won't work in PyTorch either, so we don't need to handle that case, and we don't need to consider test cases like the 'fail case' above. This means dynamic shape is supported for all dimensions.

PyTorch example of correct usage of select.int
PyTorch example of the IndexError case of select.int
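The same behavior can be reproduced with numpy's take, used here as a stand-in for torch.select, which raises IndexError in the analogous situation:

```python
import numpy as np

x = np.zeros((3, 2, 2))

# Correct usage: index 2 is within dim 0 (size 3); dim 0 is dropped.
assert np.take(x, 2, axis=0).shape == (2, 2)

# Out-of-range index: raises IndexError, so an engine built from a
# converter would never see this case in a valid PyTorch program.
try:
    np.take(x, 3, axis=0)
    raised = False
except IndexError:
    raised = True
assert raised
```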
@keehyuna Thanks for the suggestion!
Since PyTorch already raises an error when the index exceeds the input size, there’s no need for us to check this here. I removed those checks. Additionally, following your suggestion, I’ve added test cases to fully support dynamic shapes.
Do we need this cast here, since it will always be an int?

dim = get_positive_dim(cast(int, dim), ranks)

We could also change the type of dim from Shape to int in the function signature.
indices_tensor = ctx.net.add_constant(
    index_value.shape, to_numpy(index_value)
).get_output(0)

indices_tensor = ctx.net.add_constant(index_value.shape, index_value).get_output(0)
Can you use a get_trt_tensor call here?
Thanks for the suggestion! dim and index are now int (not Shape), and I've changed to using get_trt_tensor for indices_tensor.
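A toy sketch of what a helper like get_trt_tensor buys at the call site: scalars and arrays are wrapped as network constants in one place instead of repeating the add_constant boilerplate. The FakeTensor class and this reimplementation are illustrative only; the real helper also handles dtype promotion and passes existing ITensors through.

```python
import numpy as np

class FakeTensor:
    # Stand-in for a TensorRT ITensor produced by a constant layer.
    def __init__(self, array):
        self.array = np.asarray(array)
        self.shape = self.array.shape

def get_trt_tensor(ctx, value, name):
    # Pass network tensors through; wrap scalars/arrays as constants.
    if isinstance(value, FakeTensor):
        return value
    return FakeTensor(value)

index_value = np.array(2, dtype=np.int32)
indices_tensor = get_trt_tensor(None, index_value, "select_index")
assert indices_tensor.shape == ()  # 0-d constant, as for select.int
```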
LGTM
Description
Support dynamic shapes for aten.select.int. As shown below, since the arguments dim and index for select.int are int and not list, no reshaping is required.
Type of change
Checklist: