-
Notifications
You must be signed in to change notification settings - Fork 355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Numpy changes for aten::index converter #2396
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-13 16:23:31.589620+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2023-10-13 16:26:06.957081+00:00
@@ -178,12 +178,11 @@
rhs_val = cast_trt_tensor(ctx, rhs_val, trt.float32, name)
return [lhs_val, rhs_val]
def broadcastable(
- a: Union[TRTTensor, np.ndarray],
- b: Union[TRTTensor, np.ndarray]
+ a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
) -> bool:
"Check if two tensors are broadcastable according to torch rules"
a_shape = tuple(a.shape)
b_shape = tuple(b.shape)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
@@ -127,6 +140,7 @@ def index( | |||
rank = len(input_shape) | |||
adv_indx_count = len(adv_indx_indices) | |||
dim_tensor_list = [] | |||
dim_list = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is unused, dim_list = []
can be removed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pending the above comment, it looks good to me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Works well on the SD usecase, and tests pass locally - could a test case be added specifically for this addition, which tests the functionality when all inputs are constants?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @gs-olive. Regarding adding of the test for this feature, the earlier tests were all numpy indices, where is_numpy is set as True. So those tests should be good for this. Is there any way so that tests can be added so that the indices are ITensors?
Yes, there is. You would just need to pass the indices as inputs to the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-10-30 19:32:26.833431+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-10-30 19:35:28.665002+00:00
@@ -24,11 +24,11 @@
input = [torch.randn(2, 2)]
self.run_test(
TestModule(),
input,
)
-
+
def test_index_zero_two_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0]
out = torch.ops.aten.index.Tensor(x, indices)
@@ -56,25 +56,22 @@
input = [torch.randn(2, 2, 2)]
self.run_test(
TestModule(),
input,
)
-
+
def test_index_zero_index_three_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0, None]
out = torch.ops.aten.index.Tensor(x, indices)
return out
input = torch.randn(2, 2, 2)
index0 = torch.randint(0, 1, (1, 1))
index0 = index0.to(torch.int32)
- self.run_test(
- TestModule(),
- [input, index0]
- )
+ self.run_test(TestModule(), [input, index0])
def test_index_zero_index_one_index_two_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
self.index0 = torch.randint(0, 1, (1, 1))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
@@ -2,11 +2,10 @@ | |||
|
|||
import torch | |||
import torch.nn as nn | |||
from harness import DispatchTestCase |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Requires switch back to .harness
to pass CI
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok corrected, wil wait for CI then will merge.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
This PR addresses the changes for the cases where indices are list of numpy arrays or numpy for aten::index. #2394