add dynamic support for floor/logical_not/sign/round/isinf/isnan (#2963)
lanluo-nvidia authored and cehongwang committed Jul 8, 2024
1 parent 5751d1b commit 7e4da0d
Showing 7 changed files with 180 additions and 6 deletions.
14 changes: 8 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1721,7 +1721,7 @@ def aten_ops_ceil(
    )


-@dynamo_tensorrt_converter(torch.ops.aten.floor.default)
+@dynamo_tensorrt_converter(torch.ops.aten.floor.default, supports_dynamic_shapes=True)
def aten_ops_floor(
    ctx: ConversionContext,
    target: Target,
@@ -1738,7 +1738,9 @@ def aten_ops_floor(
    )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_not.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.logical_not.default, supports_dynamic_shapes=True
+)
def aten_ops_logical_not(
    ctx: ConversionContext,
    target: Target,
@@ -1755,7 +1757,7 @@ def aten_ops_logical_not(
    )


-@dynamo_tensorrt_converter(torch.ops.aten.sign.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sign.default, supports_dynamic_shapes=True)
def aten_ops_sign(
    ctx: ConversionContext,
    target: Target,
@@ -1772,7 +1774,7 @@ def aten_ops_sign(
    )


-@dynamo_tensorrt_converter(torch.ops.aten.round.default)
+@dynamo_tensorrt_converter(torch.ops.aten.round.default, supports_dynamic_shapes=True)
def aten_ops_round(
    ctx: ConversionContext,
    target: Target,
@@ -1789,7 +1791,7 @@ def aten_ops_round(
    )


-@dynamo_tensorrt_converter(torch.ops.aten.isinf.default)
+@dynamo_tensorrt_converter(torch.ops.aten.isinf.default, supports_dynamic_shapes=True)
def aten_ops_isinf(
    ctx: ConversionContext,
    target: Target,
@@ -1806,7 +1808,7 @@ def aten_ops_isinf(
    )


-@dynamo_tensorrt_converter(torch.ops.aten.isnan.default)
+@dynamo_tensorrt_converter(torch.ops.aten.isnan.default, supports_dynamic_shapes=True)
def aten_ops_isnan(
    ctx: ConversionContext,
    target: Target,
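
With these converters now registered with supports_dynamic_shapes=True, a module that uses any of these ops can be compiled once against a min/opt/max shape range and then run with differently sized inputs. A minimal usage sketch, not part of this commit, assuming the standard torch_tensorrt dynamo path on a CUDA device (FloorModule is a hypothetical example module):

import torch
import torch_tensorrt

class FloorModule(torch.nn.Module):
    def forward(self, x):
        # torch.floor traces to torch.ops.aten.floor.default, one of the ops covered here
        return torch.floor(x)

model = FloorModule().eval().cuda()

# Shape range the TensorRT engine is built for; any input whose dimensions fall
# between min_shape and max_shape can be fed at runtime without rebuilding.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 4),
        opt_shape=(2, 3, 5),
        max_shape=(3, 4, 6),
        dtype=torch.float32,
    )
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(torch.randn(2, 3, 5).cuda()))
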
26 changes: 26 additions & 0 deletions tests/py/dynamo/conversion/test_floor_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -26,6 +27,31 @@ def forward(self, input):
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), (11,), (12,)),
            ((1, 3, 4), (2, 3, 5), (3, 4, 6)),
            ((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
        ]
    )
    def test_floor_dynamic_shape(self, min_shape, opt_shape, max_shape):
        class floor(nn.Module):
            def forward(self, input):
                return torch.ops.aten.floor.default(input)

        input_specs = [
            Input(
                dtype=torch.float32,
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            floor(),
            input_specs,
        )

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
44 changes: 44 additions & 0 deletions tests/py/dynamo/conversion/test_isinf_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -37,6 +38,49 @@ def forward(self, input):
            inputs,
        )

    def test_isinf_dynamic_shape_float(self):
        class isinf(nn.Module):
            def forward(self, input):
                return torch.ops.aten.isinf.default(input)

        inputs = [
            Input(
                min_shape=(1, 2, 3),
                opt_shape=(3, 2, 3),
                max_shape=(5, 3, 3),
                dtype=torch.float32,
                torch_tensor=torch.tensor(
                    ([[[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]]]),
                    dtype=torch.float32,
                ).cuda(),
            )
        ]
        self.run_test_with_dynamic_shape(
            isinf(),
            inputs,
            use_example_tensors=False,
        )

    def test_isinf_dynamic_shape_int(self):
        class isinf(nn.Module):
            def forward(self, input):
                return torch.ops.aten.isinf.default(input)

        inputs = [
            Input(
                min_shape=(1, 2),
                opt_shape=(3, 2),
                max_shape=(5, 3),
                dtype=torch.int,
                torch_tensor=torch.tensor(([[-3, 2]]), dtype=torch.int).cuda(),
            )
        ]
        self.run_test_with_dynamic_shape(
            isinf(),
            inputs,
            use_example_tensors=False,
        )

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
24 changes: 24 additions & 0 deletions tests/py/dynamo/conversion/test_isnan_aten.py
@@ -3,6 +3,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -39,6 +40,29 @@ def forward(self, input):
            inputs,
        )

    def test_isnan_dynamic_shape_float(self):
        class isnan(nn.Module):
            def forward(self, input):
                return torch.ops.aten.isnan.default(input)

        inputs = [
            Input(
                min_shape=(1, 2, 3),
                opt_shape=(3, 2, 3),
                max_shape=(5, 3, 3),
                dtype=torch.float32,
                torch_tensor=torch.tensor(
                    ([[[3.2, float("nan"), 3.1], [float("inf"), 1.1, float("nan")]]]),
                    dtype=torch.float32,
                ).cuda(),
            )
        ]
        self.run_test_with_dynamic_shape(
            isnan(),
            inputs,
            use_example_tensors=False,
        )

    @parameterized.expand(
        [
            (torch.full((2, 2), float("nan"), dtype=torch.float32),),
26 changes: 26 additions & 0 deletions tests/py/dynamo/conversion/test_logical_not_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -60,6 +61,31 @@ def forward(self, input):
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), (11,), (13,)),
            ((1, 5), (2, 5), (3, 5)),
            ((2, 3, 4), (2, 3, 5), (3, 4, 6)),
        ]
    )
    def test_logical_not_dynamic_shape(self, min_shape, opt_shape, max_shape):
        class logical_not(nn.Module):
            def forward(self, input):
                return torch.ops.aten.logical_not.default(input)

        input_specs = [
            Input(
                dtype=torch.float32,
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            logical_not(),
            input_specs,
        )


if __name__ == "__main__":
run_tests()
26 changes: 26 additions & 0 deletions tests/py/dynamo/conversion/test_round_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -26,6 +27,31 @@ def forward(self, input):
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), (11,), (12,)),
            ((1, 3, 4), (2, 3, 5), (3, 4, 6)),
            ((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
        ]
    )
    def test_round_dynamic_shape(self, min_shape, opt_shape, max_shape):
        class round(nn.Module):
            def forward(self, input):
                return torch.ops.aten.round.default(input)

        input_specs = [
            Input(
                dtype=torch.float32,
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            round(),
            input_specs,
        )

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
26 changes: 26 additions & 0 deletions tests/py/dynamo/conversion/test_sign_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -26,6 +27,31 @@ def forward(self, input):
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), (11,), (12,)),
            ((1, 3, 4), (2, 3, 5), (3, 4, 6)),
            ((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
        ]
    )
    def test_sign_dynamic_shape(self, min_shape, opt_shape, max_shape):
        class sign(nn.Module):
            def forward(self, input):
                return torch.ops.aten.sign.default(input)

        input_specs = [
            Input(
                dtype=torch.float32,
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            sign(),
            input_specs,
        )

    @parameterized.expand(
        [
            ((10,), torch.int, -2, 2),
