
support argmax converter #2291

Merged: 4 commits into main on Oct 10, 2023
Conversation

bowang007 (Collaborator)
Description

Support argmax converter

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests labels Sep 5, 2023
@github-actions github-actions bot requested a review from gs-olive September 5, 2023 22:31
@bowang007 bowang007 changed the title from "support argmax converter" to "support argmax converter [Draft]" Sep 5, 2023
@github-actions bot commented:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py	2023-09-05 22:31:02.244529+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py	2023-09-05 22:33:23.441716+00:00
@@ -23,18 +23,15 @@
    dim: int = 0,
    keep_dim: bool = False,
) -> TRTTensor:
    if not isinstance(input, TRTTensor):
        raise RuntimeError(
-            f"argmax received input {input} that is not part "
-            "of the TensorRT region!"
+            f"argmax received input {input} that is not part " "of the TensorRT region!"
        )
    if dim < 0:
        dim = len(tuple(input.shape)) + dim
    reduce_mask = 1 << dim
    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)

    set_layer_name(topk_layer, target, name)

    return topk_layer.get_output(1)
-    
-    
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py	2023-09-05 22:31:02.264529+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py	2023-09-05 22:33:26.764451+00:00
@@ -2,33 +2,23 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from harness import DispatchTestCase

+
class TestArgmaxConverter(DispatchTestCase):
-    @parameterized.expand(
-            [
-                ("dim_0_keep_dim_false", (3, 4), 0, False)
-            ]
-    )
-
+    @parameterized.expand([("dim_0_keep_dim_false", (3, 4), 0, False)])
    def test_argmax(self, _, input_shape, dim, keep_dim):
        class ArgMax(nn.Module):
            def __init__(self):
                super().__init__()

-            def forward(self, input): 
+            def forward(self, input):
                return torch.argmax(input, dim, keep_dim)
-            

        input = [torch.randn(*input_shape)]

-        self.run_test(
-            ArgMax(),
-            input, 
-            expected_ops={torch.ops.aten.argmax.default}
-        )
+        self.run_test(ArgMax(), input, expected_ops={torch.ops.aten.argmax.default})
+

if __name__ == "__main__":
-    run_tests()  
-
-
+    run_tests()
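
The converter's core move in the diff above is encoding the reduction axis as a bitmask for TensorRT's TopK layer: negative dims are normalized, then `reduce_mask = 1 << dim` sets the one bit `add_topk` reduces over. A minimal pure-Python sketch of that axis handling, independent of TensorRT (the helper name `argmax_reduce_axes` is illustrative, not part of the PR):

```python
def argmax_reduce_axes(shape, dim):
    """Normalize a possibly-negative dim and encode it as a TopK axes bitmask.

    TensorRT's add_topk takes the reduction axes as a bit field where bit i
    selects dimension i; this mirrors `reduce_mask = 1 << dim` in the diff.
    """
    rank = len(shape)
    if dim < 0:
        dim += rank  # same normalization as `dim = len(tuple(input.shape)) + dim`
    if not 0 <= dim < rank:
        raise ValueError(f"dim {dim} out of range for tensor of rank {rank}")
    return 1 << dim
```

So for a `(3, 4)` tensor, reducing over `dim=0` sets bit 0 (mask 1), while `dim=-1` normalizes to 1 and sets bit 1 (mask 2).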

@github-actions bot: Code conforms to C++ style guidelines

@bowang007 bowang007 changed the title from "support argmax converter [Draft]" to "support argmax converter" Sep 22, 2023

@github-actions bot: Code conforms to Python style guidelines

@bowang007 bowang007 force-pushed the argmax_converter_dynamo branch from 9ca9577 to 0047b3d Compare September 22, 2023 04:23

Review threads on py/torch_tensorrt/dynamo/conversion/impl/argmax.py and py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (outdated, resolved)
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int = 0,
Collaborator review comment:

Switch to dim: Optional[int] = None since this is the default dim, as per the documentation. Alternatively, if this converter cannot support reducing over all dimensions, you can add a capability_validator to the converter to disallow inputs where the dim is not specified or non-integral.

bowang007 (Collaborator, Author) replied:

I used dim: Union[int, None], is that ok?
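
For reference, `torch.argmax` treats an omitted `dim` as a reduction over the flattened tensor and returns a flat index, which is what motivates the `Optional[int] = None` suggestion (`Union[int, None]` is the same type, just spelled differently). A pure-Python sketch of that dim handling on a row-major flat buffer, independent of torch and TensorRT (function and parameter names are illustrative):

```python
from typing import List, Optional, Sequence


def argmax(values: Sequence[float], shape: Sequence[int], dim: Optional[int] = None):
    """Stride-based argmax mimicking torch.argmax's dim semantics.

    dim=None reduces over the flattened input and returns one flat index;
    an integer dim (possibly negative) returns per-slice indices along it.
    Ties resolve to the first occurrence.
    """
    if dim is None:
        # Default behavior: argmax over all elements of the flattened tensor.
        return max(range(len(values)), key=lambda i: values[i])
    rank = len(shape)
    if dim < 0:
        dim += rank  # normalize negative dims, as the converter does
    # Row-major strides for the given shape.
    strides = [1] * rank
    for i in range(rank - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    out_shape = [s for i, s in enumerate(shape) if i != dim]
    total_out = 1
    for s in out_shape:
        total_out *= s
    result: List[int] = []
    for flat_out in range(total_out):
        # Decode flat_out into coordinates over the non-reduced dims.
        rem, coords = flat_out, []
        for s in reversed(out_shape):
            coords.append(rem % s)
            rem //= s
        coords.reverse()
        full = coords[:dim] + [0] + coords[dim:]
        base = sum(c * st for c, st in zip(full, strides))
        best, best_k = values[base], 0
        for k in range(1, shape[dim]):
            v = values[base + k * strides[dim]]
            if v > best:
                best, best_k = v, k
        result.append(best_k)
    return result
```

For example, with the 2x3 buffer `[1, 5, 3, 4, 2, 6]`, `dim=1` yields `[1, 2]` (per-row maxima), `dim=0` yields `[1, 0, 1]` (per-column maxima), and `dim=None` yields the flat index `5`.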

Review thread on py/torch_tensorrt/dynamo/conversion/impl/argmax.py (outdated, resolved)




bowang007 (Collaborator, Author) commented Oct 7, 2023:

Hey @gs-olive, I will be OOO next week.
I think this update covers all edge cases.
Please feel free to merge if this is good to go. Thanks!

@bowang007 bowang007 requested a review from gs-olive October 7, 2023 05:30


@gs-olive gs-olive force-pushed the argmax_converter_dynamo branch from 1f76a5c to ffe53e0 Compare October 10, 2023 01:23



@gs-olive gs-olive requested a review from apbose October 10, 2023 01:31


gs-olive (Collaborator) left a review comment:
Looks good to me

zewenli98 (Collaborator) left a review comment:
Looks good overall, but I left some comments about using our new APIs and small fixes.

Six review threads on py/torch_tensorrt/dynamo/conversion/impl/argmax.py (outdated, resolved)
@gs-olive gs-olive force-pushed the argmax_converter_dynamo branch from 0bf93c6 to 60c576d Compare October 10, 2023 19:44






zewenli98 (Collaborator) left a review comment:
Thanks for the fix! I found a small bug here. Otherwise it looks good to me!

Review thread on py/torch_tensorrt/dynamo/conversion/impl/argmax.py (outdated, resolved)
- Added regression test
@gs-olive gs-olive force-pushed the argmax_converter_dynamo branch from 60c576d to 668f897 Compare October 10, 2023 21:12


zewenli98 (Collaborator) left a review comment:
LGTM!

@gs-olive gs-olive merged commit f3f475b into main Oct 10, 2023
17 checks passed
@gs-olive gs-olive deleted the argmax_converter_dynamo branch October 10, 2023 22:20
gs-olive added a commit that referenced this pull request Oct 10, 2023
Signed-off-by: Bo Wang <bowa@nvidia.com>
Co-authored-by: gs-olive <113141689+gs-olive@users.noreply.github.com>
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests priority: high
4 participants