-
Notifications
You must be signed in to change notification settings - Fork 355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support argmax converter #2291
support argmax converter #2291
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py 2023-09-05 22:31:02.244529+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py 2023-09-05 22:33:23.441716+00:00
@@ -23,18 +23,15 @@
dim: int = 0,
keep_dim: bool = False,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
- f"argmax received input {input} that is not part "
- "of the TensorRT region!"
+ f"argmax received input {input} that is not part " "of the TensorRT region!"
)
if dim < 0:
dim = len(tuple(input.shape)) + dim
reduce_mask = 1 << dim
topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
set_layer_name(topk_layer, target, name)
return topk_layer.get_output(1)
-
-
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py 2023-09-05 22:31:02.264529+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py 2023-09-05 22:33:26.764451+00:00
@@ -2,33 +2,23 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from harness import DispatchTestCase
+
class TestArgmaxConverter(DispatchTestCase):
- @parameterized.expand(
- [
- ("dim_0_keep_dim_false", (3, 4), 0, False)
- ]
- )
-
+ @parameterized.expand([("dim_0_keep_dim_false", (3, 4), 0, False)])
def test_argmax(self, _, input_shape, dim, keep_dim):
class ArgMax(nn.Module):
def __init__(self):
super().__init__()
- def forward(self, input):
+ def forward(self, input):
return torch.argmax(input, dim, keep_dim)
-
input = [torch.randn(*input_shape)]
- self.run_test(
- ArgMax(),
- input,
- expected_ops={torch.ops.aten.argmax.default}
- )
+ self.run_test(ArgMax(), input, expected_ops={torch.ops.aten.argmax.default})
+
if __name__ == "__main__":
- run_tests()
-
-
+ run_tests()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
9ca9577
to
0047b3d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
source_ir: Optional[SourceIR], | ||
name: str, | ||
input: TRTTensor, | ||
dim: int = 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Switch to dim: Optional[int] = None
since this is the default dim
, as per the documentation. Alternatively, if this converter cannot support reducing over all dimensions, you can add a capability_validator
to the converter to disallow inputs where the dim
is not specified or non-integral.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used dim: Union[int, None]
, is that ok?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
9e62066
to
1f76a5c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
Hey @gs-olive I will be OOO next week. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
1f76a5c
to
ffe53e0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall, but I left some comments about using our new APIs and small fixes.
Signed-off-by: Bo Wang <bowa@nvidia.com>
0bf93c6
to
60c576d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix! I found a small bug here. Other looks good to me!
- Added regression test
60c576d
to
668f897
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Signed-off-by: Bo Wang <bowa@nvidia.com> Co-authored-by: gs-olive <113141689+gs-olive@users.noreply.github.com>
Description
Support argmax converter
Checklist: