[FX] remove op_lowering_disallow_list and format revert #1261
Conversation
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- py/torch_tensorrt/fx/input_tensor_spec.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/input_tensor_spec.py 2022-08-12 18:52:24.650475 +0000
@@ -6,14 +6,11 @@
from .utils import get_dynamic_dims
def generate_input_specs(inputs, lower_setting, additional_inputs=None):
# dynamic_batch is TRT only flag.
- if (
- not lower_setting.explicit_batch_dimension
- or lower_setting.dynamic_batch is False
- ):
+ if not lower_setting.explicit_batch_dimension or lower_setting.dynamic_batch is False:
return InputTensorSpec.from_tensors(inputs)
# If we don't have additional inputs, we assume the first dimension
# is the dynamic batch dimension. Otherwise, we use the additional
# inputs to determine the batch dimension.
@@ -33,20 +30,16 @@
for i, j in zip(inputs, additional_inputs):
found_batch_dim = False
for idx, values in enumerate(zip(i.shape, j.shape)):
if values[0] != values[1]:
- assert (
- found_batch_dim is False
- ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+ assert found_batch_dim is False, f"We've already found a batch dim, {i.shape}, {j.shape}."
batch_dims.append(idx)
found_batch_dim = True
if not found_batch_dim:
- raise RuntimeError(
- f"Failed to find batch dimension because shapes are the same, {i.shape}"
- )
+ raise RuntimeError(f"Failed to find batch dimension because shapes are the same, {i.shape}")
return InputTensorSpec.from_tensors_with_dynamic_batch_size(
inputs,
(
0,
@@ -158,13 +151,11 @@
batch_dim
), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
shape = list(tensor.shape)
shape[batch_dim] = -1
shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item]
- input_specs.append(
- cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
- )
+ input_specs.append(cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges))
return input_specs
def to_random_tensor(self):
shape = tuple(self.shape)
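The input_tensor_spec.py hunks above center on the batch-dimension search in generate_input_specs: given a second set of sample inputs, the loop zips each pair of shapes and takes the single index where they differ as the dynamic batch dimension, asserting exactly one such index exists. A minimal standalone sketch of that loop (find_batch_dim is an illustrative name, not part of the module):

from typing import Sequence

def find_batch_dim(shape_a: Sequence[int], shape_b: Sequence[int]) -> int:
    """Return the single index at which two otherwise-identical shapes differ.

    Mirrors the detection loop in generate_input_specs: two sample inputs for
    the same tensor are expected to differ only in the batch dimension.
    """
    batch_dim = None
    for idx, (a, b) in enumerate(zip(shape_a, shape_b)):
        if a != b:
            assert batch_dim is None, f"We've already found a batch dim, {shape_a}, {shape_b}."
            batch_dim = idx
    if batch_dim is None:
        raise RuntimeError(f"Failed to find batch dimension because shapes are the same, {shape_a}")
    return batch_dim

print(find_batch_dim((1, 3, 224), (8, 3, 224)))  # -> 0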
--- py/torch_tensorrt/fx/lower.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/lower.py 2022-08-12 18:52:24.739763 +0000
@@ -77,13 +77,11 @@
lower_setting: LowerSetting
timing_cache_manager: TimingCacheManager
@classmethod
def create(cls, lower_setting):
- timing_cache_manager = TimingCacheManager(
- lower_setting.timing_cache_prefix, lower_setting.save_timing_cache
- )
+ timing_cache_manager = TimingCacheManager(lower_setting.timing_cache_prefix, lower_setting.save_timing_cache)
return LowerTrtInterpreter(lower_setting, timing_cache_manager)
def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
logger.info(f"{split_name=} {self.lower_setting.input_specs=}")
@@ -103,13 +101,11 @@
interpreter = TRTInterpreter(
mod,
input_specs=self.lower_setting.input_specs,
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
explicit_precision=self.lower_setting.explicit_precision,
- logger_level=trt.Logger.VERBOSE
- if self.lower_setting.verbose_log
- else trt.Logger.WARNING,
+ logger_level=trt.Logger.VERBOSE if self.lower_setting.verbose_log else trt.Logger.WARNING,
)
interp_result: TRTInterpreterResult = interpreter.run(
max_batch_size=self.lower_setting.max_batch_size,
max_workspace_size=self.lower_setting.max_workspace_size,
@@ -129,13 +125,11 @@
self.timing_cache_manager.update_timing_cache(split_name, timing_cache)
return interp_result
-def default_split_function(
- model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting
-) -> SplitResult:
+def default_split_function(model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting) -> SplitResult:
splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
splitter.node_support_preview()
@@ -147,13 +141,11 @@
def default_lower_pass(
create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter],
) -> PassFunc:
- def lower_pass(
- mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str
- ) -> nn.Module:
+ def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str) -> nn.Module:
"""
Create a module transformation pass which lowers an `fx.GraphModule` into a
`TRTModule`
"""
interpreter = create_trt_interpreter(lower_setting)
@@ -223,21 +215,13 @@
inputs: Input,
additional_inputs: Optional[Input] = None,
) -> nn.Module:
module.eval()
- if (
- self.lower_pass_manager_builder.lower_setting.lower_precision
- == LowerPrecision.FP16
- ):
+ if self.lower_pass_manager_builder.lower_setting.lower_precision == LowerPrecision.FP16:
module.half()
- inputs = tuple(
- x.half() if x is not None and x.dtype == torch.float32 else x
- for x in inputs
- )
- pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
- inputs, additional_inputs
- )
+ inputs = tuple(x.half() if x is not None and x.dtype == torch.float32 else x for x in inputs)
+ pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(inputs, additional_inputs)
lower_result = pm(module)
return lower_result
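The final lower.py hunk shows the FP16 path: when lower_precision is FP16, the module is converted with module.half() and every float32 sample input is halved to match before the lowering pipeline is built. A small self-contained sketch of just that step (to_fp16 is a hypothetical helper, not the library's API):

import torch
import torch.nn as nn

def to_fp16(module: nn.Module, inputs: tuple) -> tuple:
    # Mirror the FP16 branch in Lowerer.__call__: convert the module's
    # parameters to half precision, then halve every float32 sample input.
    # Non-float32 entries (integer tensors, None placeholders) pass through.
    module.half()
    return tuple(
        x.half() if x is not None and x.dtype == torch.float32 else x
        for x in inputs
    )

mod = nn.Linear(4, 2)
inputs = to_fp16(mod, (torch.randn(1, 4),))
print(inputs[0].dtype)  # torch.float16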
--- py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2022-08-12 18:52:24.961256 +0000
@@ -35,23 +35,17 @@
# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input):
# >>> # print_module_and_input will be called right after the fuse passes
# >>> lower(module, sample_input)
# Observer for the model after the fuse passes.
-FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer(
- "FUSE_PASSES_POST_OBSERVER"
-)
+FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer("FUSE_PASSES_POST_OBSERVER")
# Observer for the TRT split submodules before lowering
-LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer(
- "LOWER_SPLIT_PRE_OBSERVER"
-)
+LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer("LOWER_SPLIT_PRE_OBSERVER")
# Observer for the TRT split submodules after lowering
-LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer(
- "LOWER_SPLIT_POST_OBSERVER"
-)
+LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer("LOWER_SPLIT_POST_OBSERVER")
# ----------------------------------------------------------------------
def wrapper(fn: Callable, input) -> Callable:
@wraps(fn)
@@ -103,22 +97,16 @@
passes.append(wrapper(p, self._input))
for p in self.lower_setting.lower_basic_fuse_pass.passes:
passes.append(wrapper(p, self._input))
passes.append(inplace_wrapper(common_subexpression_elimination))
- passes.append(
- inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
- )
+ passes.append(inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)))
return PassManager.build_from_passlist(passes)
def _split_pass(self) -> PassManager:
- passes = [
- partial(
- self._split_func, inputs=self._input, lower_setting=self.lower_setting
- )
- ]
+ passes = [partial(self._split_func, inputs=self._input, lower_setting=self.lower_setting)]
passes.append(
inplace_wrapper(
lambda split_result: remove_duplicate_output_args(
split_result.split_module, split_result.submodule_inputs.keys()
)
@@ -152,21 +140,15 @@
lowering_start_time = datetime.datetime.now()
self.lower_setting.input_specs = generate_input_specs(
submod_inputs,
self.lower_setting,
- additional_submodule_inputs[submod_name]
- if additional_submodule_inputs
- else None,
+ additional_submodule_inputs[submod_name] if additional_submodule_inputs else None,
)
- lowered_module = self._lower_func(
- submod, submod_inputs, self.lower_setting, submod_name
- )
+ lowered_module = self._lower_func(submod, submod_inputs, self.lower_setting, submod_name)
setattr(split_result.split_module, submod_name, lowered_module)
- LOWER_SPLIT_POST_OBSERVER.observe(
- submod_name, lowered_module, submod_inputs
- )
+ LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
)
return split_result.split_module
@@ -184,28 +166,22 @@
# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
_LOGGER.info(f"Now lowering submodule {submod_name}")
lowering_start_time = datetime.datetime.now()
- lowered_module = self._lower_func(
- submod, submod_inputs, self.lower_setting, submod_name
- )
+ lowered_module = self._lower_func(submod, submod_inputs, self.lower_setting, submod_name)
setattr(split_result.split_module, submod_name, lowered_module)
- LOWER_SPLIT_POST_OBSERVER.observe(
- submod_name, lowered_module, submod_inputs
- )
+ LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
)
return split_result.split_module
return PassManager.build_from_passlist([lower_func])
- def build_trt_lower_pipeline(
- self, input: Input, additional_input: Optional[Input] = None
- ) -> PassManager:
+ def build_trt_lower_pipeline(self, input: Input, additional_input: Optional[Input] = None) -> PassManager:
self._input = input
self._additional_input = additional_input
passes = []
passes.append(self._const_fold_pass())
@@ -214,13 +190,11 @@
passes.append(self._trt_lower_pass())
pm = PassManager.build_from_passlist(passes)
return pm
- def build_default_lower_pipeline(
- self, input: Input, additional_input: Optional[Input] = None
- ) -> PassManager:
+ def build_default_lower_pipeline(self, input: Input, additional_input: Optional[Input] = None) -> PassManager:
self._input = input
self._additional_input = additional_input
passes = []
passes.append(self._const_fold_pass())
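The observer hunks above reformat three module-level hooks; the usage comment quoted at the top of the file shows how a callback stays registered for the scope of a with-block. A sketch following that quoted example (the with-block body is a placeholder for whatever call runs the fuse passes):

import torch.nn as nn
from torch_tensorrt.fx.passes.lower_pass_manager_builder import FUSE_PASSES_POST_OBSERVER

def print_module_and_input(module: nn.Module, sample_input) -> None:
    # Per the quoted comment, this fires right after the fuse passes run.
    print(type(module).__name__, sample_input)

with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input):
    # ... run lowering here, e.g. lower(module, sample_input) as in the
    # file's own doc comment; the callback is active only in this block.
    pass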
--- py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py 2022-08-12 18:52:25.154686 +0000
@@ -27,13 +27,11 @@
count_include_pad=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.avg_pool = torch.nn.AvgPool1d(
- kernel_size, stride, padding, ceil_mode, count_include_pad
- )
+ self.avg_pool = torch.nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)
def forward(self, x):
return self.avg_pool(x)
inputs = [torch.randn(1, 3, 224)]
@@ -60,13 +58,11 @@
count_include_pad=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.avg_pool = torch.nn.AvgPool1d(
- kernel_size, stride, padding, ceil_mode, count_include_pad
- )
+ self.avg_pool = torch.nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)
def forward(self, x):
return self.avg_pool(x)
input_specs = [
@@ -75,13 +71,11 @@
dtype=torch.float32,
shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d})
def test_avg_pool2d_with_dynamic_shape_four_dimensions(
self,
test_name="default",
kernel_size=1,
@@ -112,13 +106,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d})
@parameterized.expand(
[
("default", 1),
("kernal_size", 3),
@@ -254,12 +246,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py 2022-08-12 18:52:25.193116 +0000
@@ -32,13 +32,11 @@
dtype=torch.float32,
shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.batch_norm})
def test_batchnorm_with_dynamic_shape(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -53,13 +51,11 @@
dtype=torch.float32,
shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.batch_norm})
# Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm.
if __name__ == "__main__":
--- py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py 2022-08-12 18:52:25.320784 +0000
@@ -51,12 +51,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.clamp}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.clamp})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py 2022-08-12 18:52:25.418225 +0000
@@ -27,13 +27,11 @@
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv1d(
- 3, 6, kernel_size, stride, padding, dilation, groups, bias
- )
+ self.conv = torch.nn.Conv1d(3, 6, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
return self.conv(x)
inputs = [torch.randn(1, 3, 32)]
@@ -60,13 +58,11 @@
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv1d(
- 3, 6, kernel_size, stride, padding, dilation, groups, bias
- )
+ self.conv = torch.nn.Conv1d(3, 6, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
return self.conv(x)
input_specs = [
@@ -75,13 +71,11 @@
dtype=torch.float32,
shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.conv1d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv1d})
@parameterized.expand(
[
("default", 1),
param("no_bias", 1, bias=False),
@@ -102,13 +96,11 @@
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv2d(
- 3, 6, kernel_size, stride, padding, dilation, groups, bias
- )
+ self.conv = torch.nn.Conv2d(3, 6, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
return self.conv(x)
inputs = [torch.randn(1, 3, 32, 32)]
@@ -131,13 +123,11 @@
shape=(-1, 3, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.conv2d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv2d})
@parameterized.expand(
[
("default", 1),
param("no_bias", 1, bias=False),
@@ -158,13 +148,11 @@
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv3d(
- 3, 6, kernel_size, stride, padding, dilation, groups, bias
- )
+ self.conv = torch.nn.Conv3d(3, 6, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
return self.conv(x)
inputs = [torch.randn(1, 3, 32, 32, 32)]
@@ -187,12 +175,10 @@
shape=(-1, 3, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.conv3d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv3d})
if __name__ == "__main__":
run_tests()
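The dynamic-shape tests in these files all follow one pattern: build an InputTensorSpec with -1 marking dynamic dimensions plus explicit (min, opt, max) shape_ranges, then call run_test_with_dynamic_shape with the expected acc op. Restating the spec from the conv2d test above as a standalone snippet (import path taken from the test headers in this diff):

import torch
from torch_tensorrt.fx.tools.common_fx2trt import InputTensorSpec

# -1 marks a dynamic dimension; shape_ranges supplies one
# (min_shape, opt_shape, max_shape) triple per optimization profile.
input_specs = [
    InputTensorSpec(
        shape=(-1, 3, -1, -1),
        dtype=torch.float32,
        shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))],
    ),
]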
--- py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py 2022-08-12 18:52:25.578605 +0000
@@ -5,13 +5,11 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
-@unittest.skip(
- reason="Could not find CustomGeluPluginDynamic. Enable it once we upgrade TRT to 8.4"
-)
+@unittest.skip(reason="Could not find CustomGeluPluginDynamic. Enable it once we upgrade TRT to 8.4")
class TestGELU(AccTestCase):
def test_gelu(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.gelu(x)
@@ -34,13 +32,11 @@
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.gelu}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.gelu})
def test_gelu_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.gelu(x)
@@ -51,12 +47,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.gelu}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.gelu})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py 2022-08-12 18:52:25.793707 +0000
@@ -131,12 +131,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Interpolate(), input_specs, expected_ops={acc_ops.interpolate}
- )
+ self.run_test_with_dynamic_shape(Interpolate(), input_specs, expected_ops={acc_ops.interpolate})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py 2022-08-12 18:52:26.083066 +0000
@@ -71,15 +71,11 @@
class MatMul(nn.Module):
def forward(self, input, other):
return torch.matmul(input, other)
inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
- test_implicit_batch_dim = (
- input_shape[0] == other_shape[0]
- and len(input_shape) > 2
- and len(other_shape) > 2
- )
+ test_implicit_batch_dim = input_shape[0] == other_shape[0] and len(input_shape) > 2 and len(other_shape) > 2
self.run_test(
MatMul(),
inputs,
expected_ops={acc_ops.matmul},
test_implicit_batch_dim=test_implicit_batch_dim,
@@ -106,12 +102,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 3, 3), (9, 4, 3, 3), (9, 4, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Matmul(), input_specs, expected_ops={acc_ops.matmul}
- )
+ self.run_test_with_dynamic_shape(Matmul(), input_specs, expected_ops={acc_ops.matmul})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_max.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_max.py 2022-08-12 18:52:26.198919 +0000
@@ -102,13 +102,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce}
- )
+ self.run_test_with_dynamic_shape(MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce})
def test_max_full_reduce(
self,
):
class MaxFullReduce(torch.nn.Module):
@@ -124,13 +122,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce}
- )
+ self.run_test_with_dynamic_shape(MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce})
def test_max_method(self):
class MaxMethod(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -149,12 +145,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MaxMethod(), input_specs, expected_ops={acc_ops.maximum}
- )
+ self.run_test_with_dynamic_shape(MaxMethod(), input_specs, expected_ops={acc_ops.maximum})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_min.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_min.py 2022-08-12 18:52:26.358565 +0000
@@ -101,13 +101,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce}
- )
+ self.run_test_with_dynamic_shape(MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce})
def test_min_full_reduce(
self,
):
class MinFullReduce(torch.nn.Module):
@@ -123,13 +121,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce}
- )
+ self.run_test_with_dynamic_shape(MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce})
def test_min_method(self):
class MinMethod(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -148,12 +144,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MinMethod(), input_specs, expected_ops={acc_ops.minimum}
- )
+ self.run_test_with_dynamic_shape(MinMethod(), input_specs, expected_ops={acc_ops.minimum})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py 2022-08-12 18:52:26.408067 +0000
@@ -23,13 +23,11 @@
dtype=torch.float32,
shape_ranges=[((1, 3, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- Narrow(), input_specs, expected_ops={acc_ops.slice_tensor}
- )
+ self.run_test_with_dynamic_shape(Narrow(), input_specs, expected_ops={acc_ops.slice_tensor})
class TestNarrowConverter(AccTestCase):
@parameterized.expand(
[
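Many of the suggested joins in the converters diff below merge a multi-line raise into one line containing two adjacent string literals. Python concatenates adjacent literals at compile time, so the joined form still yields a single message, as this self-contained check shows:

# Adjacent string literals (including f-strings) concatenate at compile
# time, so the one-line raises below still produce one error message.
msg = f"Conv received input {3} that is not part " "of the TensorRT region!"
print(msg)  # Conv received input 3 that is not part of the TensorRT region!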
--- py/torch_tensorrt/fx/converters/acc_ops_converters.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/converters/acc_ops_converters.py 2022-08-12 18:52:26.424854 +0000
@@ -34,14 +34,11 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Conv received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Conv received input {input_val} that is not part " "of the TensorRT region!")
# Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d
unsqueeze_layer = network.add_shuffle(input=input_val)
unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
@@ -52,13 +49,11 @@
# for now we'll assume bias is constant Tensor or None,
# and bias being ITensor is not supported in TensorRT api
# right now
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
- raise RuntimeError(
- f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]"
- )
+ raise RuntimeError(f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]")
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
if bias is not None:
bias = bias[None]
weight = kwargs["weight"]
@@ -82,13 +77,11 @@
)
layer.set_input(1, weight)
else:
if not isinstance(kwargs["weight"], torch.Tensor):
- raise RuntimeError(
- f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]"
- )
+ raise RuntimeError(f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]")
weight = to_numpy(weight)
weight = np.expand_dims(weight, -1)
layer = network.add_convolution_nd(
input=input_val,
num_output_maps=weight.shape[0],
@@ -126,25 +119,20 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Conv received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Conv received input {input_val} that is not part " "of the TensorRT region!")
if has_dynamic_shape(input_val.shape):
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
# for now we'll assume bias is constant Tensor or None,
# and bias being ITensor is not supported in TensorRT api
# right now
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
- raise RuntimeError(
- f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]"
- )
+ raise RuntimeError(f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]")
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
if network.has_explicit_precision:
weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
@@ -160,13 +148,11 @@
)
layer.set_input(1, weight)
else:
if not isinstance(kwargs["weight"], torch.Tensor):
- raise RuntimeError(
- f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]"
- )
+ raise RuntimeError(f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]")
weight = to_numpy(kwargs["weight"])
layer = network.add_convolution_nd(
input=input_val,
num_output_maps=weight.shape[0],
kernel_shape=weight.shape[2:],
@@ -194,27 +180,20 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Transpose conv received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Transpose conv received input {input_val} that is not part " "of the TensorRT region!")
if has_dynamic_shape(input_val.shape):
- assert (
- input_val.shape[1] != -1
- ), "Channel dim can't be dynamic for transpose convolution."
+ assert input_val.shape[1] != -1, "Channel dim can't be dynamic for transpose convolution."
# for now we'll assume bias is constant Tensor or None,
# and bias being ITensor is not supported in TensorRT api
# right now
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
- raise RuntimeError(
- f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
- )
+ raise RuntimeError(f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]")
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
if network.has_explicit_precision:
weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
@@ -232,13 +211,11 @@
)
layer.set_input(1, weight)
else:
if not isinstance(kwargs["weight"], torch.Tensor):
- raise RuntimeError(
- f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
- )
+ raise RuntimeError(f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]")
weight = to_numpy(kwargs["weight"])
# nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
layer = network.add_deconvolution_nd(
input=input_val,
num_output_maps=weight.shape[1] * kwargs["groups"],
@@ -270,29 +247,20 @@
mode = kwargs["mode"]
value = kwargs["value"] if kwargs["value"] is not None else 0
rank = len(input_val.shape) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"pad received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"pad received input {input_val} that is not part " "of the TensorRT region!")
if mode != "constant":
- raise RuntimeError(
- f"Currently we only support constant mode for pad, got {mode}."
- )
+ raise RuntimeError(f"Currently we only support constant mode for pad, got {mode}.")
if len(pad) / 2 > rank:
- raise RuntimeError(
- f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
- )
+ raise RuntimeError(f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension.")
if value != 0:
- raise RuntimeError(
- f"Currently we only support padding value of 0, got {value}."
- )
+ raise RuntimeError(f"Currently we only support padding value of 0, got {value}.")
if len(pad) > 4:
raise RuntimeError("Currently we only support padding last two dimensions.")
pre_padding = tuple(pad[len(pad) - i - 2] for i in range(0, len(pad), 2))
@@ -320,38 +288,28 @@
mode = kwargs["mode"]
value = kwargs["value"] if kwargs["value"] is not None else 0
rank = len(input_val.shape) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"pad received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"pad received input {input_val} that is not part " "of the TensorRT region!")
if mode != "constant":
- raise RuntimeError(
- f"Currently we only support constant mode for pad, got {mode}."
- )
+ raise RuntimeError(f"Currently we only support constant mode for pad, got {mode}.")
if len(pad) / 2 > rank:
- raise RuntimeError(
- f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
- )
+ raise RuntimeError(f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension.")
# cast value to TRTensor
dt = torch_dtype_from_trt(input_val.dtype)
value = 0 if value == None else value
- value_const = get_trt_tensor(
- network, torch.tensor([value], dtype=dt), f"{name}_value"
- )
+ value_const = get_trt_tensor(network, torch.tensor([value], dtype=dt), f"{name}_value")
input_shape = input_val.shape
pre_start = tuple(i - 1 for i in input_shape)
prefix_len = len(input_shape) - len(pad) // 2
pre_shape = tuple(
- input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0)
- for i in range(0, len(input_shape))
+ input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0) for i in range(0, len(input_shape))
)
pre_stride = [-1] * len(input_shape)
layer = network.add_slice(
input_val,
@@ -374,12 +332,11 @@
transpose_output = layer.get_output(0)
shape = transpose_output.shape
post_start = tuple([0] * len(shape))
post_shape = tuple(
- shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0)
- for i in range(0, len(shape))
+ shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0) for i in range(0, len(shape))
)
post_stride = tuple([1] * len(shape))
layer = network.add_slice(transpose_output, post_start, post_shape, post_stride)
layer.set_input(4, value_const)
@@ -397,22 +354,15 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"flatten received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"flatten received input {input_val} that is not part " "of the TensorRT region!")
num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
- start_dim = get_positive_dim(
- cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims
- )
- end_dim = get_positive_dim(
- cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims
- )
+ start_dim = get_positive_dim(cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims)
+ end_dim = get_positive_dim(cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims)
if network.has_implicit_batch_dimension:
assert start_dim != 0, "Can't flatten batch dimension when it's implicit."
start_dim -= 1
end_dim -= 1
@@ -511,24 +461,18 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
if type(input_t) == torch.nn.Parameter or type(input_t) == torch.Tensor:
- if (
- not has_dynamic_shape(input_t.shape)
- and network.has_implicit_batch_dimension
- ):
+ if not has_dynamic_shape(input_t.shape) and network.has_implicit_batch_dimension:
return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_t.shape))
return input_t.shape
# input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
input_val = input_t
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"size received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"size received input {input_val} that is not part " "of the TensorRT region!")
if not has_dynamic_shape(input_val.shape):
if network.has_implicit_batch_dimension:
return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_val.shape))
return torch.Size(input_val.shape)
@@ -547,14 +491,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"size received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"size received input {input_val} that is not part " "of the TensorRT region!")
if has_dynamic_shape(input_val.shape):
raise RuntimeError(f"numel does not support dynamic shapes.")
numel = np.prod(input_val.shape)
@@ -572,29 +513,20 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"BatchNorm2d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"BatchNorm2d received input {input_val} that is not part " "of the TensorRT region!")
if has_dynamic_shape(input_val.shape):
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
- scale = cast(
- torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))
- ) / np.sqrt(
- cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"])))
- + cast(float, kwargs["eps"])
+ scale = cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))) / np.sqrt(
+ cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"]))) + cast(float, kwargs["eps"])
)
- bias = (
- to_numpy(cast(torch.Tensor, kwargs["bias"]))
- - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
- )
+ bias = to_numpy(cast(torch.Tensor, kwargs["bias"])) - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
power = np.ones_like(scale)
# For BatchNorm1d, reshape 1d to 2d
output_shape = input_val.shape
if not network.has_implicit_batch_dimension and len(input_val.shape) < 4:
@@ -628,44 +560,33 @@
@tensorrt_converter(acc_ops.layer_norm)
def acc_ops_layer_norm(network, target, args, kwargs, name):
input_val = kwargs["input"]
if not isinstance(input_val, trt.tensorrt.ITensor):
- raise RuntimeError(
- f"LayerNorm received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"LayerNorm received input {input_val} that is not part " "of the TensorRT region!")
gamma = kwargs["weight"].detach().cpu().float().numpy()
gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
beta = kwargs["bias"].detach().cpu().float().numpy()
beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
- eps_field = trt.PluginField(
- "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
- )
+ eps_field = trt.PluginField("eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32)
try:
normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
except TypeError:
_LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
normalized_shape = np.array([], dtype=np.int32)
- normalized_shape_filed = trt.PluginField(
- "normalized_shape", normalized_shape, trt.PluginFieldType.INT32
- )
- field_collection = trt.PluginFieldCollection(
- [gamma_field, beta_field, eps_field, normalized_shape_filed]
- )
+ normalized_shape_filed = trt.PluginField("normalized_shape", normalized_shape, trt.PluginFieldType.INT32)
+ field_collection = trt.PluginFieldCollection([gamma_field, beta_field, eps_field, normalized_shape_filed])
try:
if network.has_implicit_batch_dimension:
plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
else:
plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
except AssertionError:
- _LOGGER.error(
- "Unable to find layer norm plugin, fall back to TensorRT implementation."
- )
+ _LOGGER.error("Unable to find layer norm plugin, fall back to TensorRT implementation.")
return layer_norm(network, target, args, kwargs, name)
layer = network.add_plugin_v2([input_val], plugin)
layer.name = name
return layer.get_output(0)
@@ -678,14 +599,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"LayerNorm received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"LayerNorm received input {input_val} that is not part " "of the TensorRT region!")
shape = kwargs["weight"].shape # type: ignore[union-attr]
broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape
gamma = to_numpy(kwargs["weight"].reshape(*shape)) # type: ignore[union-attr]
beta = to_numpy(kwargs["bias"].reshape(*shape)) # type: ignore[union-attr]
@@ -694,13 +612,11 @@
axes = 0
for d in range(len(shape)):
axes |= 1 << (len(input_val.shape) - d - 1)
# E[x]
- mean_expected_layer = network.add_reduce(
- input_val, trt.ReduceOperation.AVG, axes, keep_dims=True
- )
+ mean_expected_layer = network.add_reduce(input_val, trt.ReduceOperation.AVG, axes, keep_dims=True)
set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
# X-E[x]
sub_trt = add_binary_elementwise_layer(
network,
@@ -722,13 +638,11 @@
pow_tensor.get_output(0),
trt.ElementWiseOperation.POW,
target,
f"{name}_pow_var",
)
- mean_trt_layer = network.add_reduce(
- pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
- )
+ mean_trt_layer = network.add_reduce(pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True)
set_layer_name(mean_trt_layer, target, f"{name}_mean")
# Variance + eps
eps_tensor = network.add_constant(
(1,) * len(input_val.shape),
trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
@@ -741,13 +655,11 @@
trt.ElementWiseOperation.SUM,
target,
f"{name}_add",
)
# SQRT((Var + eps))
- sqrt_trt = add_unary_layer(
- network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt"
- )
+ sqrt_trt = add_unary_layer(network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt")
# (x - E[x]) / sqrt((var + eps))
div_trt = add_binary_elementwise_layer(
network,
sub_trt,
sqrt_trt,
@@ -791,14 +703,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"softmax received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"softmax received input {input_val} that is not part " "of the TensorRT region!")
# Used to get dim when dim is None. Copied from PyTorch softmax implementation.
def get_softmax_dim(ndim: int) -> int:
if ndim == 0 or ndim == 1 or ndim == 3:
ret = 0
@@ -832,13 +741,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
input_val = get_trt_tensor(network, input_t, f"{name}_input")
dims = tuple(cast(Sequence[int], kwargs["dims"]))
- n_input_dims = len(input_val.shape) + (
- 1 if network.has_implicit_batch_dimension else 0
- )
+ n_input_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
if len(dims) > n_input_dims:
assert not network.has_implicit_batch_dimension
layer = network.add_shuffle(input_val)
layer.name = f"{name}_reshape"
@@ -849,20 +756,16 @@
input_shape_layer.name = f"{name}_input_shape"
preceding_ones = network.add_constant(
(num_preceding_ones,),
np.ascontiguousarray([1] * num_preceding_ones, np.int32),
).get_output(0)
- reshape_layer = network.add_concatenation(
- [preceding_ones, input_shape_layer.get_output(0)]
- )
+ reshape_layer = network.add_concatenation([preceding_ones, input_shape_layer.get_output(0)])
reshape_layer.axis = 0
reshape_layer.name = f"{name}_reshape_dims"
layer.set_input(1, reshape_layer.get_output(0))
else:
- layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(
- input_val.shape
- )
+ layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(input_val.shape)
input_val = layer.get_output(0)
else:
dims = (1,) * (n_input_dims - len(dims)) + dims
if network.has_implicit_batch_dimension:
@@ -898,17 +801,15 @@
layer = network.add_slice(input_val, starts, shapes, strides)
layer.mode = trt.SliceMode.WRAP
set_layer_name(layer, target, name)
if has_dynamic_shape(input_val.shape): # type: ignore[union-attr]
- starts_tensor = network.add_constant(
- (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
- ).get_output(0)
+ starts_tensor = network.add_constant((len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)).get_output(
+ 0
+ )
if all(isinstance(d, int) for d in dims):
- dims_tensor = network.add_constant(
- (len(dims),), np.ascontiguousarray(dims, np.int32)
- ).get_output(0)
+ dims_tensor = network.add_constant((len(dims),), np.ascontiguousarray(dims, np.int32)).get_output(0)
else:
assert all(isinstance(d, TRTTensor) for d in dims)
concat_dims_layer = network.add_concatenation(inputs=dims)
concat_dims_layer.axis = 0
concat_dims_layer.name = f"{name}_tile_dim"
@@ -969,13 +870,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
negative_slope = kwargs["negative_slope"]
operation_type = trt.ActivationType.LEAKY_RELU
- return add_activation_layer(
- network, input_val, operation_type, target, name, negative_slope
- )
+ return add_activation_layer(network, input_val, operation_type, target, name, negative_slope)
@tensorrt_converter(acc_ops.elu)
def acc_ops_elu(
network: TRTNetwork,
@@ -1243,51 +1142,40 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
- return add_reduce_layer(
- network, target, args, kwargs, trt.ReduceOperation.SUM, name
- )
+ return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.SUM, name)
@tensorrt_converter(acc_ops.prod)
def acc_ops_prod(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
- return add_reduce_layer(
- network, target, args, kwargs, trt.ReduceOperation.PROD, name
- )
+ return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.PROD, name)
@tensorrt_converter(acc_ops.mean)
def acc_ops_mean(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
- return add_reduce_layer(
- network, target, args, kwargs, trt.ReduceOperation.AVG, name
- )
+ return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.AVG, name)
def add_acc_ops_full_reduce(network, target, args, kwargs, name, reduce_op):
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"max received input {input_val} that is not part "
- "of the TensorRT region!"
- )
- assert (
- not network.has_implicit_batch_dimension
- ), "Do not support max over all the elements for implicit batch."
+ raise RuntimeError(f"max received input {input_val} that is not part " "of the TensorRT region!")
+ assert not network.has_implicit_batch_dimension, "Do not support max over all the elements for implicit batch."
dim = range(len(input_val.shape))
layer = network.add_reduce(
input_val,
@@ -1307,25 +1195,21 @@
new_kwargs["largest"] = True
elif reduce_op == trt.ReduceOperation.MIN:
new_kwargs["largest"] = False
new_kwargs["sorted"] = False
- topk_out0, topk_out1 = acc_ops_topk(
- network, target, args, new_kwargs, name + "_topk"
- )
+ topk_out0, topk_out1 = acc_ops_topk(network, target, args, new_kwargs, name + "_topk")
topk_out0.name = f"{name}_topk0"
topk_out1.name = f"{name}_topk1"
if "keepdim" in new_kwargs and new_kwargs["keepdim"]:
return topk_out0, topk_out1
dim = new_kwargs["dim"]
if network.has_implicit_batch_dimension:
- assert (
- dim != 0
- ), "can't reduce on dim == 0 when network has implicit batch dimension"
+ assert dim != 0, "can't reduce on dim == 0 when network has implicit batch dimension"
# we remove the first dim in the shape tuple when it is implicit
dim -= 1
input_val = topk_out0
shape = input_val.shape
@@ -1355,52 +1239,44 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return add_acc_ops_full_reduce(
- network, target, args, kwargs, name, trt.ReduceOperation.MAX
- )
+ return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)
@tensorrt_converter(acc_ops.min_full_reduce, no_implicit_batch_dim=True)
def acc_ops_min_full_reduce(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return add_acc_ops_full_reduce(
- network, target, args, kwargs, name, trt.ReduceOperation.MIN
- )
+ return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)
@tensorrt_converter(acc_ops.max_dim_reduce)
def acc_ops_max_dim_reduce(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return add_acc_ops_dim_reduce(
- network, target, args, kwargs, name, trt.ReduceOperation.MAX
- )
+ return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)
@tensorrt_converter(acc_ops.min_dim_reduce)
def acc_ops_min_dim_reduce(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return add_acc_ops_dim_reduce(
- network, target, args, kwargs, name, trt.ReduceOperation.MIN
- )
+ return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)
@tensorrt_converter(acc_ops.maximum)
def acc_ops_maximum(
network: TRTNetwork,
@@ -1503,32 +1379,24 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `logical_and` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `logical_and` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
# we only support both inputs are bool type
if target == acc_ops.bitwise_and:
def check_is_bool(input_t):
if isinstance(input_t, TRTTensor):
- assert (
- input_t.dtype == trt.bool
- ), "We currently do not support input is non-bool"
+ assert input_t.dtype == trt.bool, "We currently do not support input is non-bool"
elif isinstance(input_t, torch.Tensor):
- assert (
- input_t.dtype == torch.bool
- ), "We currently do not support input is non-bool"
+ assert input_t.dtype == torch.bool, "We currently do not support input is non-bool"
else:
- assert isinstance(
- input_t.bool
- ), "We currently do not support input is non-bool"
+ assert isinstance(input_t.bool), "We currently do not support input is non-bool"
check_is_bool(input_t)
check_is_bool(other_t)
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
@@ -1536,13 +1404,11 @@
if input_t.dtype != trt.bool:
input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool)
if other_t.dtype != trt.bool:
other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.AND, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.AND, target, name)
@tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True)
def acc_ops_ne(
network: TRTNetwork,
@@ -1550,24 +1416,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `ne` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `ne` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
- eq_t = add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
- )
+ eq_t = add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name)
return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name)
@tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True)
@@ -1577,24 +1439,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `eq` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `eq` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name)
@tensorrt_converter(acc_ops.gt, no_implicit_batch_dim=True)
def acc_ops_gt(
network: TRTNetwork,
@@ -1602,24 +1460,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `gt` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `gt` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name)
@tensorrt_converter(acc_ops.lt, no_implicit_batch_dim=True)
def acc_ops_lt(
network: TRTNetwork,
@@ -1627,24 +1481,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `le` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `le` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name)
@tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True)
def acc_ops_logical_or(
network: TRTNetwork,
@@ -1652,13 +1502,11 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `logical_or` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `logical_or` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
if isinstance(other_t, (torch.Tensor, bool)):
if isinstance(other_t, bool):
@@ -1675,13 +1523,11 @@
layer_o = network.add_identity(other_t)
layer_o.set_output_type(0, trt.bool)
set_layer_name(layer_o, target, f"{name}_other_dtype_change")
other_t = layer_o.get_output(0)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.OR, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.OR, target, name)
@tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True)
def acc_ops_logical_xor(
network: TRTNetwork,
@@ -1689,13 +1535,11 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `logical_xor` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `logical_xor` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
if isinstance(other_t, (torch.Tensor, bool)):
if isinstance(other_t, bool):
@@ -1712,13 +1556,11 @@
layer_o = network.add_identity(other_t)
layer_o.set_output_type(0, trt.bool)
set_layer_name(layer_o, target, f"{name}_other_dtype_change")
other_t = layer_o.get_output(0)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name)
# T113156424 Have some accuracy problems in hf_T5.
# [TRT] [W] Weights [name=isinf_1_inf_t]: Converted FP32 value in weights (either FP32 infinity or FP32 value outside FP16 range) to corresponding FP16 infinity. If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights.
# @tensorrt_converter(acc_ops.isinf)
@@ -1764,26 +1606,19 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
if not isinstance(input_t, TRTTensor):
- raise RuntimeError(
- f"isinf received input {input_t} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"isinf received input {input_t} that is not part " "of the TensorRT region!")
if input_t.dtype in (trt.float32, trt.float16, trt.int32):
- comp_t = torch.zeros(tuple([*input_t.shape])).to(
- torch_dtype_from_trt(input_t.dtype)
- )
+ comp_t = torch.zeros(tuple([*input_t.shape])).to(torch_dtype_from_trt(input_t.dtype))
comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
kwargs_new = {"input": input_t, "other": comp_t}
eq_output = acc_ops_eq(network, target, None, kwargs_new, name + "_eq")
kwargs_new = {"input": eq_output}
- not_output = acc_ops_logical_not(
- network, target, None, kwargs_new, name + "_not"
- )
+ not_output = acc_ops_logical_not(network, target, None, kwargs_new, name + "_not")
else:
not_output = input_t
# cast bool result to int
int_output = type_cast(network, target, f"{name}_cast_int", not_output, trt.int32)
# sum
@@ -1809,13 +1644,11 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # NOTE: TRT doesn't currently implement fmod so we need multiple operations to perform it
- trunc_div_value = trunc_div(
- kwargs["input"], kwargs["other"], network, target, name + "_trunc_div"
- )
+ trunc_div_value = trunc_div(kwargs["input"], kwargs["other"], network, target, name + "_trunc_div")
prod_value = add_binary_elementwise_layer(
network,
trunc_div_value,
kwargs["other"],
trt.ElementWiseOperation.PROD,
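A side note on the decomposition above: since TensorRT has no native fmod, the converter builds it from a truncating division, a multiply, and (outside this hunk) a subtraction. A minimal PyTorch sketch of the identity fmod(a, b) = a - trunc(a / b) * b, for checking the math:

import torch

def fmod_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # trunc rounds toward zero, so the remainder keeps the sign of `a`,
    # matching torch.fmod's convention.
    return a - torch.trunc(a / b) * b

a = torch.tensor([5.0, -5.0, 5.0, -5.0])
b = torch.tensor([3.0, 3.0, -3.0, -3.0])
assert torch.equal(fmod_sketch(a, b), torch.fmod(a, b))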
@@ -1907,14 +1740,11 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_trt = kwargs["input"]
if not isinstance(input_trt, TRTTensor):
- raise RuntimeError(
- f"Max_pool1d received input {input_trt} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Max_pool1d received input {input_trt} that is not part " "of the TensorRT region!")
# adds unsqueeze layer -> max pool 2d -> squeeze layer to emulate max pool 1d.
unsqueeze_layer = network.add_shuffle(input=input_trt)
unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
@@ -1929,25 +1759,16 @@
ceil_mode = kwargs["ceil_mode"]
if len(stride) == 0 or stride[0] == None:
stride = kernel_size
- if any(
- [
- not isinstance(param, int)
- for param in [kernel_size[0], stride[0], padding[0], dilation[0]]
- ]
- ):
- raise RuntimeError(
- f"Parameters kernel_size, stride, padding, and dilation should be of type int."
- )
+ if any([not isinstance(param, int) for param in [kernel_size[0], stride[0], padding[0], dilation[0]]]):
+ raise RuntimeError(f"Parameters kernel_size, stride, padding, and dilation should be of type int.")
if dilation[0] != 1:
raise RuntimeError(f"Only support dilation=1 for maxpool, but got {dilation}")
- max_pooling_layer = network.add_pooling(
- input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1)
- )
+ max_pooling_layer = network.add_pooling(input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1))
max_pooling_layer.stride_nd = stride + (1,)
max_pooling_layer.padding_nd = padding + (0,)
set_layer_name(max_pooling_layer, target, name)
if ceil_mode:
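The emulation described by the comment in this hunk (unsqueeze -> max pool 2d -> squeeze) can be sanity-checked in eager PyTorch; a minimal sketch, assuming single ints for kernel, stride, and padding:

import torch
import torch.nn.functional as F

def max_pool1d_via_2d(x, kernel_size, stride, padding):
    # Append a trailing size-1 spatial dim, pool with a (k, 1) window,
    # then squeeze the dummy dim away: the same shuffle/pool/shuffle idea.
    y = x.unsqueeze(-1)                              # (N, C, L) -> (N, C, L, 1)
    y = F.max_pool2d(y, (kernel_size, 1), (stride, 1), (padding, 0))
    return y.squeeze(-1)                             # back to (N, C, L_out)

x = torch.randn(2, 4, 16)
ref = F.max_pool1d(x, kernel_size=3, stride=2, padding=1)
assert torch.equal(max_pool1d_via_2d(x, 3, 2, 1), ref)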
@@ -1969,14 +1790,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"MaxPool2d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"MaxPool2d received input {input_val} that is not part " "of the TensorRT region!")
extend_len = 2 if target == acc_ops.max_pool2d else 3
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
dilation = extend_attr_to_tuple(kwargs["dilation"], extend_len)
@@ -1985,17 +1803,13 @@
if len(stride) == 0 or stride[0] == None:
stride = kernel_size
ones = (1,) * extend_len
if dilation != ones:
- raise RuntimeError(
- f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
- )
-
- layer = network.add_pooling_nd(
- input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
- )
+ raise RuntimeError(f"Only support dilation=(1, 1) for maxpool, but got {dilation}")
+
+ layer = network.add_pooling_nd(input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size)
layer.stride_nd = stride
layer.padding_nd = padding
set_layer_name(layer, target, name)
if ceil_mode:
@@ -2013,23 +1827,18 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"squeeze received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"squeeze received input {input_val} that is not part " "of the TensorRT region!")
dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
# Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
    # dim, which is a very rare case. For now we simply don't support dim=None.
assert dim is not None, "We don't support dim=None right now for squeeze."
- dim = get_positive_dim(
- dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
- )
+ dim = get_positive_dim(dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
if network.has_implicit_batch_dimension:
assert dim != 0, "We don't support squeeze batch dim when it's implicit."
dim -= 1
assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
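The collapsed get_positive_dim call above packs some easy-to-misread bookkeeping: the dim is normalized against the rank including the implicit batch dim, then shifted by one because TRT shapes omit that dim. An illustrative sketch (squeeze_dim_sketch is not the repo's helper):

def squeeze_dim_sketch(trt_shape, dim, implicit_batch):
    # Normalize against the full rank (+1 when the batch dim is implicit) ...
    rank = len(trt_shape) + (1 if implicit_batch else 0)
    dim = dim + rank if dim < 0 else dim
    if implicit_batch:
        assert dim != 0, "can't squeeze the implicit batch dim"
        dim -= 1  # ... then shift, since TRT shapes omit the batch dim.
    return dim

# squeezing dim=-1 of an (N, 3, 1) input under implicit batch:
assert squeeze_dim_sketch((3, 1), -1, implicit_batch=True) == 1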
@@ -2176,35 +1985,26 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"unsqueeze received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"unsqueeze received input {input_val} that is not part " "of the TensorRT region!")
dim = cast(int, kwargs["dim"])
input_shape = input_val.shape
- input_shape_size = (
- len(input_val.shape) + 1
- if network.has_implicit_batch_dimension
- else len(input_val.shape)
- )
+ input_shape_size = len(input_val.shape) + 1 if network.has_implicit_batch_dimension else len(input_val.shape)
dim = get_positive_dim(dim, input_shape_size + 1)
if network.has_implicit_batch_dimension:
assert dim != 0
dim -= 1
assert (
len(get_dynamic_dims(input_val.shape)) <= 1
), "Currently we don't support unsqueeze with more than one dynamic dims."
layer = network.add_shuffle(input_val)
- layer.reshape_dims = (
- tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
- )
+ layer.reshape_dims = tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
set_layer_name(layer, target, name)
return layer.get_output(0)
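The reshape_dims expression above is the entire unsqueeze: splice a 1 into the shape at dim. Eager equivalent:

import torch

def unsqueeze_via_reshape(x: torch.Tensor, dim: int) -> torch.Tensor:
    # shape[:dim] + (1,) + shape[dim:], exactly as the shuffle layer is configured.
    shape = tuple(x.shape)
    return x.reshape(shape[:dim] + (1,) + shape[dim:])

x = torch.randn(2, 3)
assert torch.equal(unsqueeze_via_reshape(x, 1), x.unsqueeze(1))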
@tensorrt_converter(acc_ops.topk)
@@ -2216,14 +2016,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"topk received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"topk received input {input_val} that is not part " "of the TensorRT region!")
if kwargs["sorted"] and kwargs["k"] != 1:
raise RuntimeError("Currently we don't support sorted=True in topk.")
if not network.has_implicit_batch_dimension and len(input_val.shape) <= 1:
@@ -2253,40 +2050,28 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"AdaptiveAvgPool2d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"AdaptiveAvgPool2d received input {input_val} that is not part " "of the TensorRT region!")
extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
assert all(
input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
- output_size = cast(
- Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
- )
+ output_size = cast(Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len))
for input_dim, output_dim in zip(input_val.shape[-extend_len:], output_size):
if input_dim % output_dim != 0:
raise RuntimeError(
"For AdaptiveAvgPool, input dim has to be integer multiple of output dim."
f"Got input dim {input_dim}, output dim {output_dim}"
)
- stride = tuple(
- input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
- )
- kernel_size = tuple(
- input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
- for i in range(extend_len)
- )
- layer = network.add_pooling_nd(
- input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
- )
+ stride = tuple(input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len))
+ kernel_size = tuple(input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i] for i in range(extend_len))
+ layer = network.add_pooling_nd(input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
layer.stride_nd = stride
set_layer_name(layer, target, name)
return layer.get_output(0)
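The stride/kernel arithmetic above is the usual reduction of adaptive pooling to fixed pooling, stride = in // out and kernel = in - (out - 1) * stride, which is exact only when each input dim is an integer multiple of the output dim (hence the RuntimeError). A quick equivalence check:

import torch
import torch.nn.functional as F

def adaptive_avg_as_fixed(x, out_hw):
    # Valid only when every input dim divides evenly by its output dim,
    # mirroring the converter's restriction.
    in_hw = x.shape[-2:]
    stride = tuple(i // o for i, o in zip(in_hw, out_hw))
    kernel = tuple(i - (o - 1) * s for i, o, s in zip(in_hw, out_hw, stride))
    return F.avg_pool2d(x, kernel_size=kernel, stride=stride)

x = torch.randn(1, 3, 8, 12)
assert torch.allclose(adaptive_avg_as_fixed(x, (4, 6)), F.adaptive_avg_pool2d(x, (4, 6)))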
@@ -2300,14 +2085,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"AvgPool1d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"AvgPool1d received input {input_val} that is not part " "of the TensorRT region!")
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 1)
stride = extend_attr_to_tuple(kwargs["stride"], 1)
padding = extend_attr_to_tuple(kwargs["padding"], 1)
ceil_mode = kwargs["ceil_mode"]
@@ -2319,13 +2101,11 @@
shuffle_layer = network.add_shuffle(input_val)
shuffle_layer.reshape_dims = tuple(input_val.shape) + (1,)
set_layer_name(shuffle_layer, target, name + "_shuffle1")
shuffle_out = shuffle_layer.get_output(0)
- layer = network.add_pooling_nd(
- input=shuffle_out, type=trt.PoolingType.AVERAGE, window_size=(kernel_size[0], 1)
- )
+ layer = network.add_pooling_nd(input=shuffle_out, type=trt.PoolingType.AVERAGE, window_size=(kernel_size[0], 1))
layer.stride_nd = stride + (1,)
layer.padding_nd = padding + (0,)
layer.average_count_excludes_padding = False if count_include_pad else True
set_layer_name(layer, target, name)
@@ -2349,14 +2129,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"AvgPool2d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"AvgPool2d received input {input_val} that is not part " "of the TensorRT region!")
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 2)
stride = extend_attr_to_tuple(kwargs["stride"], 2)
padding = extend_attr_to_tuple(kwargs["padding"], 2)
ceil_mode = kwargs["ceil_mode"]
@@ -2367,13 +2144,11 @@
stride = kernel_size
if divisor_override:
raise RuntimeError("TensorRT does not support divisor_override.")
- layer = network.add_pooling(
- input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
- )
+ layer = network.add_pooling(input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
layer.stride = stride
layer.padding = padding
layer.average_count_excludes_padding = False if count_include_pad else True
set_layer_name(layer, target, name)
@@ -2433,23 +2208,18 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"slice_tensor received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"slice_tensor received input {input_val} that is not part " "of the TensorRT region!")
ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
dynamic_shape = has_dynamic_shape(input_val.shape)
if network.has_implicit_batch_dimension:
if dim == 0:
- raise RuntimeError(
- f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
- )
+ raise RuntimeError(f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!")
dim = dim - 1
else:
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
@@ -2463,13 +2233,11 @@
stride[dim] = step_int
output_shape = list(input_val.shape)
output_shape[dim] = (stop_int - start_int) // step_int
if dynamic_shape > 0:
- output_shape = get_shape_with_dynamic_shape(
- network, output_shape, input_val, target, name
- )
+ output_shape = get_shape_with_dynamic_shape(network, output_shape, input_val, target, name)
layer = network.add_slice(
input_val,
start=start,
shape=[] if dynamic_shape else output_shape,
stride=stride,
@@ -2502,13 +2270,11 @@
shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
inshape = tuple(input_val.shape)
shape = tuple(shape)
start = tuple([0] * ranks)
- stride = tuple(
- [int(i == o) for i, o in zip(inshape, shape)]
- ) # stride == 1 if dimensions match, 0 otherwise
+ stride = tuple([int(i == o) for i, o in zip(inshape, shape)]) # stride == 1 if dimensions match, 0 otherwise
layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
set_layer_name(layer, target, name)
return layer.get_output(0)
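The trailing comment on the stride line is the whole trick: this expand lowering uses a slice whose per-dim stride is 1 where the sizes already match and 0 where a size-1 dim must repeat. The eager analogue is a zero-stride view (illustrative, not the converter's code path):

import torch

def expand_via_strides(x: torch.Tensor, shape):
    # Keep the real memory stride where dims match; use stride 0 to
    # repeat a broadcast dim. Same 1-vs-0 rule as the slice layer.
    strides = tuple(s if i == o else 0 for s, i, o in zip(x.stride(), x.shape, shape))
    return torch.as_strided(x, shape, strides)

x = torch.randn(1, 3)
assert torch.equal(expand_via_strides(x, (4, 3)), x.expand(4, 3))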
@@ -2615,13 +2381,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
mask_t = kwargs["mask"]
value_t = kwargs["value"]
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "We don't support masked_fill with implicit batch dimension due to select layer!"
- )
+ raise RuntimeError("We don't support masked_fill with implicit batch dimension due to select layer!")
shape = list(input_t.shape)
mask_shape = list(mask_t.shape)
assert type(value_t) in (
@@ -2674,14 +2438,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"split received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"split received input {input_val} that is not part " "of the TensorRT region!")
dim = cast(int, kwargs["dim"])
dynamic_shape = has_dynamic_shape(input_val.shape)
if network.has_implicit_batch_dimension:
assert dim != 0, "Can't split on batch dim when it's implicit!"
@@ -2695,28 +2456,22 @@
start = [0] * len(input_val.shape)
stride = [1] * len(start)
offset = 0
num_splits = (input_val.shape[dim] + split_size - 1) // split_size
if num_splits < 1:
- raise RuntimeError(
- f"Invalid split: {input_val.shape[dim]} with split_size={split_size}"
- )
+ raise RuntimeError(f"Invalid split: {input_val.shape[dim]} with split_size={split_size}")
max_offset = input_val.shape[dim]
# add slice layers
output = []
for i in range(num_splits):
shape = list(input_val.shape)
shape[dim] = min(split_size, cast(int, max_offset - offset))
start[dim] = offset
if dynamic_shape:
- shape = get_shape_with_dynamic_shape(
- network, shape, input_val, target, f"{name}_shape_{i}"
- )
- layer = network.add_slice(
- input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
- )
+ shape = get_shape_with_dynamic_shape(network, shape, input_val, target, f"{name}_shape_{i}")
+ layer = network.add_slice(input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride)
if dynamic_shape:
layer.set_input(2, shape)
offset += split_size
set_layer_name(layer, target, f"{name}_{i}")
output.append(layer.get_output(0))
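The split loop above is ceiling division plus a clamped final slice. The same logic in eager terms (split_sketch is illustrative):

import torch

def split_sketch(x, split_size, dim):
    length = x.shape[dim]
    num_splits = (length + split_size - 1) // split_size  # ceil division
    out, offset = [], 0
    for _ in range(num_splits):
        size = min(split_size, length - offset)  # last slice may be short
        out.append(x.narrow(dim, offset, size))
        offset += split_size
    return out

x = torch.arange(10.0)
assert all(torch.equal(a, b) for a, b in zip(split_sketch(x, 3, 0), torch.split(x, 3)))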
@@ -2732,19 +2487,15 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Linear received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Linear received input {input_val} that is not part " "of the TensorRT region!")
dynamic_dims = get_dynamic_dims(input_val.shape)
assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, (
- "Currently we only support one dynmaic "
- "dim for linear and it can't be the last dim."
+ "Currently we only support one dynmaic " "dim for linear and it can't be the last dim."
)
if isinstance(kwargs["weight"], torch.Tensor):
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
weight_op = trt.MatrixOperation.NONE
@@ -2760,13 +2511,11 @@
preset_diff -= 1
input_op = trt.MatrixOperation.VECTOR
else:
input_op = trt.MatrixOperation.NONE
- input_val, weight = broadcast(
- network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff
- )
+ input_val, weight = broadcast(network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff)
matmul_layer = network.add_matrix_multiply(input_val, input_op, weight, weight_op)
set_layer_name(matmul_layer, target, f"{name}_matmul")
res = matmul_layer.get_output(0)
if kwargs["bias"] is not None:
@@ -2782,16 +2531,11 @@
return res
def add_clamp(network, input, val, op):
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
- acc_ops_clamp_tensor = (
- val
- * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
- .cpu()
- .numpy()
- )
+ acc_ops_clamp_tensor = val * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)).cpu().numpy()
acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor)
layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op)
return layer
@@ -2807,25 +2551,18 @@
input_val = kwargs["input"]
min_val = kwargs["min"]
max_val = kwargs["max"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Clamp received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Clamp received input {input_val} that is not part " "of the TensorRT region!")
if min_val is not None:
- clamp_min_layer = add_clamp(
- network, input_val, min_val, trt.ElementWiseOperation.MAX
- )
+ clamp_min_layer = add_clamp(network, input_val, min_val, trt.ElementWiseOperation.MAX)
set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
input_val = clamp_min_layer.get_output(0)
if max_val is not None:
- clamp_max_layer = add_clamp(
- network, input_val, max_val, trt.ElementWiseOperation.MIN
- )
+ clamp_max_layer = add_clamp(network, input_val, max_val, trt.ElementWiseOperation.MIN)
set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
input_val = clamp_max_layer.get_output(0)
return input_val
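add_clamp above realizes clamp as two elementwise ops against rank-matched broadcast constants: MAX against the lower bound, then MIN against the upper bound. A quick eager check of that composition:

import torch

def clamp_via_elementwise(x, min_val=None, max_val=None):
    # Elementwise MAX raises values to the floor; MIN caps them at the ceiling.
    if min_val is not None:
        x = torch.maximum(x, torch.full_like(x, min_val))
    if max_val is not None:
        x = torch.minimum(x, torch.full_like(x, max_val))
    return x

x = torch.tensor([-2.0, 0.5, 3.0])
assert torch.equal(clamp_via_elementwise(x, 0.0, 1.0), torch.clamp(x, 0.0, 1.0))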
@@ -2883,30 +2620,22 @@
def slice_to_trt_params(py_slice, dim_size):
"""
Convert python slice to TensorRT slice layer parameters.
"""
- start = (
- get_positive_dim(py_slice.start, dim_size) if py_slice.start != None else 0
- )
+ start = get_positive_dim(py_slice.start, dim_size) if py_slice.start != None else 0
stride = py_slice.step if py_slice.step != None else 1
- stop = (
- get_positive_dim(py_slice.stop, dim_size)
- if py_slice.stop != None
- else dim_size
- )
+ stop = get_positive_dim(py_slice.stop, dim_size) if py_slice.stop != None else dim_size
size = math.ceil((stop - start) * 1.0 / stride)
return start, size, stride
if network.has_implicit_batch_dimension:
# Raise an error if it's trying to subscript batch dimension unless it's
# slice(None, None, None).
batch_subscript = slices[0]
if batch_subscript not in [slice(None, None, None), slice(0, None, None)]:
- raise RuntimeError(
- f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}"
- )
+ raise RuntimeError(f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}")
# Remove batch_dim subscript
slices = slices[1:]
# Replace ellipsis with expanded slices.
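For readers checking slice_to_trt_params, a runnable restatement of the start/size/stride math (positive_dim is a hypothetical stand-in for the repo's get_positive_dim):

import math

def positive_dim(d: int, dim_size: int) -> int:
    return d + dim_size if d < 0 else d

def slice_to_params(py_slice: slice, dim_size: int):
    start = positive_dim(py_slice.start, dim_size) if py_slice.start is not None else 0
    stride = py_slice.step if py_slice.step is not None else 1
    stop = positive_dim(py_slice.stop, dim_size) if py_slice.stop is not None else dim_size
    size = math.ceil((stop - start) / stride)  # number of selected elements
    return start, size, stride

start, size, stride = slice_to_params(slice(1, -1, 2), 10)
assert [start + i * stride for i in range(size)] == list(range(10))[1:-1:2]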
@@ -2995,13 +2724,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
tensors = kwargs["tensors"]
dim = kwargs["dim"]
if any(not isinstance(t, TRTTensor) for t in tensors): # type: ignore[union-attr]
- raise RuntimeError(
- f"cat received inputs {tensors} that is not part " "of the TensorRT region!"
- )
+ raise RuntimeError(f"cat received inputs {tensors} that is not part " "of the TensorRT region!")
layer = network.add_concatenation(inputs=tensors)
if dim < 0:
if network.has_implicit_batch_dimension:
dim = len(tensors[0].shape) + 1 + dim
else:
@@ -3023,13 +2750,11 @@
input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other")
for i in [input_val, other_val]:
if not isinstance(i, TRTTensor):
- raise RuntimeError(
- f"matmul received input {i} that is not part of the TensorRT region!"
- )
+ raise RuntimeError(f"matmul received input {i} that is not part of the TensorRT region!")
input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
preset_diff = 0
if len(input_val.shape) == 1:
@@ -3038,16 +2763,12 @@
if len(other_val.shape) == 1:
preset_diff += 1
other_matrix_op = trt.MatrixOperation.VECTOR
- input_val, other_val = broadcast(
- network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
- )
- layer = network.add_matrix_multiply(
- input_val, input_matrix_op, other_val, other_matrix_op
- )
+ input_val, other_val = broadcast(network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff)
+ layer = network.add_matrix_multiply(input_val, input_matrix_op, other_val, other_matrix_op)
set_layer_name(layer, target, name)
return layer.get_output(0)
@tensorrt_converter(acc_ops.hardsigmoid)
@@ -3059,14 +2780,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Hard sigmoid received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Hard sigmoid received input {input_val} that is not part " "of the TensorRT region!")
return add_activation_layer(
network,
input_val,
trt.ActivationType.HARD_SIGMOID,
@@ -3086,18 +2804,13 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Sigmoid received input {input_val} that is not part "
- "of the TensorRT region!"
- )
-
- return add_activation_layer(
- network, input_val, trt.ActivationType.SIGMOID, target, name
- )
+ raise RuntimeError(f"Sigmoid received input {input_val} that is not part " "of the TensorRT region!")
+
+ return add_activation_layer(network, input_val, trt.ActivationType.SIGMOID, target, name)
@tensorrt_converter(acc_ops.permute)
def acc_ops_permute(
network: TRTNetwork,
@@ -3113,14 +2826,11 @@
else:
index = kwargs["permutation"]
permutation = [get_positive_dim(i, ranks) for i in cast(Sequence[int], index)]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"permute received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"permute received input {input_val} that is not part " "of the TensorRT region!")
if network.has_implicit_batch_dimension:
assert permutation[0] == 0, "Can't permute batch dimension when it's implicit."
permutation = [i - 1 for i in permutation[1:]]
@@ -3139,14 +2849,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"{name} received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")
qparams = kwargs["acc_out_ty"].qparams # type: ignore[misc]
q_scale = qparams["scale"]
q_zero_point = qparams["zero_point"]
dtype = kwargs["acc_out_ty"].dtype # type: ignore[misc]
@@ -3157,13 +2864,11 @@
)
if q_zero_point != 0:
raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
- scale_layer = network.add_constant(
- (1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32))
- )
+ scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
scale_layer.name = input_val.name + ".per_tensor_quant.scale"
scale = scale_layer.get_output(0)
# assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
# "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
layer = network.add_quantize(input=input_val, scale=scale)
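The zero_point == 0 restriction above exists because TensorRT only supports symmetric quantization. A minimal numeric sketch of the round trip, assuming the usual symmetric int8 rule q = clamp(round(x / scale), -128, 127):

import torch

scale = 0.1
x = torch.tensor([-0.5, 0.0, 0.3, 1.2])
q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
x_hat = q.float() * scale  # dequantize: symmetric, so no zero_point term
assert torch.allclose(x_hat, x, atol=scale / 2)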
@@ -3181,14 +2886,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"{name} received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")
qparams = kwargs["acc_out_ty"].qparams # type: ignore[misc]
q_per_channel_scales = qparams["scale"]
q_per_channel_zero_points = qparams["zero_point"]
q_per_channel_axis = qparams["axis"]
@@ -3201,17 +2903,13 @@
# Make sure zero_points are all 0 because only symmetric quantization
# is supported in TensorRT
if not torch.equal(
q_per_channel_zero_points,
- torch.zeros(
- q_per_channel_zero_points.shape, dtype=q_per_channel_zero_points.dtype
- ),
+ torch.zeros(q_per_channel_zero_points.shape, dtype=q_per_channel_zero_points.dtype),
):
- raise RuntimeError(
- f"Only support zero_point == 0, get {q_per_channel_zero_points}"
- )
+ raise RuntimeError(f"Only support zero_point == 0, get {q_per_channel_zero_points}")
if not torch.all(torch.ge(q_per_channel_scales, 0)):
raise RuntimeError(f"All scale values must be >= 0, get {q_per_channel_scales}")
scale_layer = network.add_constant(
@@ -3238,14 +2936,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
input_val_tensor_meta = kwargs["_itensor_to_tensor_meta"][input_val] # type: ignore[index]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"{name} received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")
qparams = input_val_tensor_meta.qparams # type: ignore[misc]
qscheme = qparams["qscheme"]
if qscheme == torch.per_tensor_affine:
q_scale = qparams["scale"]
@@ -3256,30 +2951,25 @@
raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
elif qscheme == torch.per_channel_affine:
q_scale = qparams["scale"]
q_zero_point = qparams["zero_point"]
q_axis = qparams["axis"]
- assert isinstance(
- q_scale, immutable_list
- ), "expected q_scale to be immutable_list got {}".format(type(q_scale))
+ assert isinstance(q_scale, immutable_list), "expected q_scale to be immutable_list got {}".format(type(q_scale))
scale_shape = (len(q_scale),)
if any(x != 0 for x in q_zero_point):
raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
else:
raise RuntimeError("Unsupported qscheme in dequantize: {qscheme}")
dtype = input_val_tensor_meta.dtype # type: ignore[misc]
if dtype not in (torch.quint8, torch.qint8, torch.qint32):
raise RuntimeError(
- "Only support (torch.quint8, torch.qint8, torch.qint32) "
- f"quantized type in dequantize, get {dtype}."
+ "Only support (torch.quint8, torch.qint8, torch.qint32) " f"quantized type in dequantize, get {dtype}."
)
- scale_layer = network.add_constant(
- scale_shape, trt.Weights(np.ascontiguousarray(q_scale, dtype=np.float32))
- )
+ scale_layer = network.add_constant(scale_shape, trt.Weights(np.ascontiguousarray(q_scale, dtype=np.float32)))
scale_layer.name = input_val.name + ".dequant.scale"
scale = scale_layer.get_output(0)
# assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
# "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
layer = network.add_dequantize(input=input_val, scale=scale)
@@ -3296,24 +2986,17 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"GELU received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"GELU received input {input_val} that is not part " "of the TensorRT region!")
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "GeLU converter currently doesn't support implicit batch dimension"
- )
+ raise RuntimeError("GeLU converter currently doesn't support implicit batch dimension")
plugin_name = "CustomGeluPluginDynamic"
# type_id 0 for float32, 1 for float16
- type_id = trt.PluginField(
- "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
- )
+ type_id = trt.PluginField("type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32)
field_collection = TRTPluginFieldCollection([type_id])
plugin_version = "1"
plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
@@ -3334,14 +3017,11 @@
chunks = cast(int, kwargs["chunks"])
dim = cast(int, kwargs["dim"])
input_dim_size = len(input_val.shape) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"chunk received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"chunk received input {input_val} that is not part " "of the TensorRT region!")
dynamic_shape = has_dynamic_shape(input_val.shape)
if network.has_implicit_batch_dimension:
input_dim_size += 1
dim = get_positive_dim(dim, input_dim_size)
@@ -3371,17 +3051,13 @@
output = []
for i in range(chunks):
shape = list(input_val.shape)
shape[dim] = min(split_size, max_offset - offset)
if dynamic_shape:
- shape = get_shape_with_dynamic_shape(
- network, shape, input_val, target, f"{name}_{i}"
- )
+ shape = get_shape_with_dynamic_shape(network, shape, input_val, target, f"{name}_{i}")
start[dim] = offset
- layer = network.add_slice(
- input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
- )
+ layer = network.add_slice(input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride)
if dynamic_shape:
layer.set_input(2, shape)
offset += split_size
set_layer_name(layer, target, f"{name}_{i}")
output.append(layer.get_output(0))
@@ -3400,18 +3076,13 @@
dim = cast(int, kwargs["dim"])
input_shape = input_val.shape # type: ignore[union-attr]
input_dim_size = len(input_val.shape) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"cumsum received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"cumsum received input {input_val} that is not part " "of the TensorRT region!")
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "cumsum converter currently doesn't support implicit batch dimension"
- )
+ raise RuntimeError("cumsum converter currently doesn't support implicit batch dimension")
dim = get_positive_dim(dim, input_dim_size)
loop = network.add_loop()
trip_limit = None
if input_shape[dim] > 0:
axis = torch.tensor(input_shape[dim], dtype=torch.int32)
@@ -3427,13 +3098,11 @@
loop.add_trip_limit(trip_limit, trt.TripLimit(0))
iterator = loop.add_iterator(input_val, dim, False)
data = iterator.get_output(0)
new_dims = tuple(data.shape)
zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
- zero_tensor = network.add_constant(
- zero_tensor.shape, to_numpy(zero_tensor)
- ).get_output(0)
+ zero_tensor = network.add_constant(zero_tensor.shape, to_numpy(zero_tensor)).get_output(0)
running_sum = loop.add_recurrence(zero_tensor)
set_layer_name(running_sum, target, f"{name}_running_sum_1")
running_sum_tensor = running_sum.get_output(0)
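The loop machinery in this hunk (trip limit, iterator over dim, recurrence seeded with zeros) encodes a plain running sum. The same recurrence in eager PyTorch:

import torch

def cumsum_via_loop(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Slice one step along `dim`, add into a zero-initialized running
    # sum, and emit every partial sum, matching the TRT recurrence.
    running = torch.zeros_like(x.select(dim, 0))
    outs = []
    for i in range(x.shape[dim]):
        running = running + x.select(dim, i)
        outs.append(running)
    return torch.stack(outs, dim=dim)

x = torch.randn(2, 5)
assert torch.allclose(cumsum_via_loop(x, 1), torch.cumsum(x, dim=1))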
@@ -3476,14 +3145,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"hardtanh received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"hardtanh received input {input_val} that is not part " "of the TensorRT region!")
return add_activation_layer(
network,
input_val,
trt.ActivationType.CLIP,
@@ -3507,26 +3173,19 @@
scale_factor = kwargs["scale_factor"]
mode = kwargs["mode"]
align_corners = kwargs["align_corners"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"interpolate received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"interpolate received input {input_val} that is not part " "of the TensorRT region!")
dim = input_val.shape
ranks = len(input_val.shape)
if network.has_implicit_batch_dimension:
- assert (
- ranks >= 2 and ranks <= 4
- ), "Interpolate expects inputs are 3D,4D,5D in shape"
+ assert ranks >= 2 and ranks <= 4, "Interpolate expects inputs are 3D,4D,5D in shape"
ranks = ranks - 1
else:
- assert (
- ranks >= 3 and ranks <= 5
- ), "Interpolate expects inputs are 3D,4D,5D in shape"
+ assert ranks >= 3 and ranks <= 5, "Interpolate expects inputs are 3D,4D,5D in shape"
ranks = ranks - 2
layer = network.add_resize(input_val)
if network.has_implicit_batch_dimension:
if size != None:
@@ -3555,13 +3214,11 @@
layer.resize_mode = trt.ResizeMode.LINEAR
else:
layer.resize_mode = trt.ResizeMode.NEAREST
if align_corners != None:
- layer.coordinate_transformation = (
- trt.ResizeCoordinateTransformation.ALIGN_CORNERS
- )
+ layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ALIGN_CORNERS
set_layer_name(layer, target, name)
return layer.get_output(0)
@@ -3579,13 +3236,11 @@
if dtype_val is None:
dtype_val = input_val.dtype
dtype_val = torch_dtype_from_trt(dtype_val)
device_val = kwargs.get("device")
- assert (
- device_val == "cuda" or device_val == None
- ), f"device is not `cuda` but {device_val}"
+ assert device_val == "cuda" or device_val == None, f"device is not `cuda` but {device_val}"
weight = torch.ones(size_val, dtype=dtype_val)
return get_trt_tensor(network, weight, f"{name}_weight")
@@ -3603,13 +3258,11 @@
if dtype_val is None:
dtype_val = input_val.dtype
dtype_val = torch_dtype_from_trt(dtype_val)
device_val = kwargs.get("device")
- assert (
- device_val == "cuda" or device_val == None
- ), f"device is not `cuda` but {device_val}"
+ assert device_val == "cuda" or device_val == None, f"device is not `cuda` but {device_val}"
weight = torch.zeros(size_val, dtype=dtype_val)
return get_trt_tensor(network, weight, f"{name}_weight")
@@ -3634,13 +3287,11 @@
input_val[i] = get_trt_tensor(network, input_source, name + f"_input_source{i}")
if const_flag:
for i, input_source in enumerate(input_val):
if input_source.dtype != trt.float32:
- input_val[i] = type_cast(
- network, target, f"{name}_input_cast{i}", input_source, trt.float32
- )
+ input_val[i] = type_cast(network, target, f"{name}_input_cast{i}", input_source, trt.float32)
einsum_layer = network.add_einsum(inputs=input_val, equation=equation)
return einsum_layer.get_output(0)
@tensorrt_converter(acc_ops.as_strided)
--- py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py 2022-08-12 18:52:26.557706 +0000
@@ -70,13 +70,11 @@
inputs,
expected_ops={expected_acc_op},
test_implicit_batch_dim=(dim != 0),
)
- @parameterized.expand(
- [(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)]
- )
+ @parameterized.expand([(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)])
def test_prod_all_dims(
self,
test_name,
op,
expected_acc_op,
@@ -107,12 +105,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
),
]
- self.run_test_with_dynamic_shape(
- Prod(), input_specs, expected_ops={acc_ops.prod}
- )
+ self.run_test_with_dynamic_shape(Prod(), input_specs, expected_ops={acc_ops.prod})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py 2022-08-12 18:48:48.790461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py 2022-08-12 18:52:26.565275 +0000
@@ -50,16 +50,11 @@
inputs,
expected_ops={expected_acc_op},
test_implicit_batch_dim=(dim != 0),
)
- @parameterized.expand(
- [
- (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op)
- for op, acc_op in reduce_ops
- ]
- )
+ @parameterized.expand([(f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) for op, acc_op in reduce_ops])
def test_reduce_all_dims(
self,
test_name,
op,
expected_acc_op,
@@ -74,16 +69,11 @@
inputs,
expected_ops={expected_acc_op},
test_implicit_batch_dim=False,
)
- @parameterized.expand(
- [
- (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op)
- for op, acc_op in reduce_ops
- ]
- )
+ @parameterized.expand([(f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) for op, acc_op in reduce_ops])
def test_reduce_all_dims_with_dynamic_shape_four_dimensions(
self,
test_name,
op,
expected_acc_op,
@@ -97,12 +87,10 @@
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Reduce(), input_specs, expected_ops={expected_acc_op}
- )
+ self.run_test_with_dynamic_shape(Reduce(), input_specs, expected_ops={expected_acc_op})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py 2022-08-12 18:52:26.764386 +0000
@@ -26,14 +26,11 @@
inputs = [torch.randn(*input_shape)]
self.run_test(
Tile(dims),
inputs,
expected_ops={acc_ops.tile},
- test_implicit_batch_dim=(
- len(input_shape) > len(dims)
- or (len(input_shape) == len(dims) and dims[0] == 1)
- ),
+ test_implicit_batch_dim=(len(input_shape) > len(dims) or (len(input_shape) == len(dims) and dims[0] == 1)),
)
@parameterized.expand(
[
("same_num_dims", (-1, 2, 3), (1, 2, 2)),
@@ -62,13 +59,11 @@
tuple(i if i != -1 else 3 for i in shape),
)
],
),
]
- self.run_test_with_dynamic_shape(
- Tile(dims), input_specs, expected_ops={acc_ops.tile}
- )
+ self.run_test_with_dynamic_shape(Tile(dims), input_specs, expected_ops={acc_ops.tile})
@parameterized.expand(
[
("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)),
]
@@ -88,13 +83,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Tile(dims), input_specs, expected_ops={acc_ops.tile}
- )
+ self.run_test_with_dynamic_shape(Tile(dims), input_specs, expected_ops={acc_ops.tile})
def test_tile_non_int_dims(self):
class Tile(nn.Module):
def __init__(self):
super().__init__()
@@ -103,13 +96,11 @@
y = y * 2
return torch.tile(x, (1, y.shape[1], y.shape[1]))
inputs = [torch.randn(2, 2, 3), torch.randn(2, 2, 3)]
batch_size_range = (1, 2, 3)
- input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
- inputs, batch_size_range
- )
+ input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(inputs, batch_size_range)
self.run_test_with_dynamic_shape(
Tile(),
input_specs,
expected_ops={acc_ops.tile},
)
@@ -134,12 +125,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Tile(), input_specs, expected_ops={acc_ops.tile}
- )
+ self.run_test_with_dynamic_shape(Tile(), input_specs, expected_ops={acc_ops.tile})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py 2022-08-12 18:52:26.814449 +0000
@@ -51,13 +51,11 @@
input = torch.randn(2, 2).to(torch.float16)
inputs = [
input,
]
- self.run_test(
- To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False
- )
+ self.run_test(To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False)
def test_cuda_fp16(self):
class To(torch.nn.Module):
def forward(self, x):
return x.to(torch.device("cuda:0"), torch.float16)
@@ -106,13 +104,11 @@
dtype=torch.float16,
shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add}
- )
+ self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add})
def test_device(self):
class To(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -152,13 +148,11 @@
dtype=torch.float16,
shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add}
- )
+ self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add})
def test_device_fp16(self):
class To(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -244,13 +238,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- To(), input_specs, expected_ops={acc_ops.to_dtype}
- )
+ self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype})
# Half is not suitable for dynamic shape
# Error: assert engine
# tensor.half()
@@ -307,12 +299,10 @@
dtype=torch.int,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- To(), input_specs, expected_ops={acc_ops.to_dtype}
- )
+ self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py 2022-08-12 18:52:26.820155 +0000
@@ -24,13 +24,11 @@
self.dim = dim
self.largest = largest
def forward(self, x):
if self.dim is not None:
- out = torch.topk(
- x, k=self.k, dim=self.dim, largest=self.largest, sorted=False
- )
+ out = torch.topk(x, k=self.k, dim=self.dim, largest=self.largest, sorted=False)
else:
out = torch.topk(x, k=self.k, largest=self.largest, sorted=False)
return out[0], out[1]
inputs = [torch.randn(1, 2, 3, 4)]
@@ -58,13 +56,11 @@
self.dim = dim
self.largest = largest
def forward(self, x):
if self.dim is not None:
- out = torch.topk(
- x, k=self.k, dim=self.dim, largest=self.largest, sorted=False
- )
+ out = torch.topk(x, k=self.k, dim=self.dim, largest=self.largest, sorted=False)
else:
out = torch.topk(x, k=self.k, largest=self.largest, sorted=False)
return out[0], out[1]
input_specs = [
@@ -73,12 +69,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TopK(k, dim), input_specs, expected_ops={acc_ops.topk}
- )
+ self.run_test_with_dynamic_shape(TopK(k, dim), input_specs, expected_ops={acc_ops.topk})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py 2022-08-12 18:52:26.956415 +0000
@@ -62,13 +62,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(orig_op), input_specs, expected_ops={expected_op}
- )
+ self.run_test_with_dynamic_shape(TestModule(orig_op), input_specs, expected_ops={expected_op})
class TestUnaryOpNotConverters(AccTestCase):
@parameterized.expand(
[
@@ -87,13 +85,11 @@
x = self.orig_op(x)
return self.orig_op(x)
m = TestModule(orig_op)
inputs = [torch.randn(2, 2, 3).to(input_dtype)]
- self.run_test(
- m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False
- )
+ self.run_test(m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False)
class TestUnaryOpNotConvertersWithDynamicShapeFourDimensions(AccTestCase):
@parameterized.expand(
[
@@ -118,13 +114,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(orig_op), input_specs, expected_ops={expected_op}
- )
+ self.run_test_with_dynamic_shape(TestModule(orig_op), input_specs, expected_ops={expected_op})
class TestUnaryRSQRTConverters(AccTestCase):
def test_unary_ops(self):
class TestModule(nn.Module):
@@ -148,12 +142,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py 2022-08-12 18:52:26.977250 +0000
@@ -35,26 +35,22 @@
self._validate_spec(spec, tensor)
def test_from_tensors_with_dynamic_batch_size(self):
tensors = [torch.randn(1, 2, 3), torch.randn(1, 4)]
batch_size_range = [2, 3, 4]
- specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
- tensors, batch_size_range
- )
+ specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(tensors, batch_size_range)
for spec, tensor in zip(specs, tensors):
self._validate_spec(spec, tensor, dynamic_dims=[0])
for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
self.assertEqual(batch_size, shape[0])
self.assertSequenceEqual(tensor.shape[1:], shape[1:])
def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self):
tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)]
batch_size_range = [2, 3, 4]
- specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
- tensors, batch_size_range, batch_dims=[0, 1]
- )
+ specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(tensors, batch_size_range, batch_dims=[0, 1])
for i, spec_and_tensor in enumerate(zip(specs, tensors)):
spec, tensor = spec_and_tensor
self._validate_spec(spec, tensor, dynamic_dims=[i])
for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
@@ -62,13 +58,11 @@
tensor_shape = list(tensor.shape)
tensor_shape[i] = batch_size
self.assertSequenceEqual(tensor_shape, shape)
def test_generate_input_specs(self):
- lower_setting = LowerSetting(
- explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2
- )
+ lower_setting = LowerSetting(explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2)
# Implicit batch dim.
inputs = [torch.randn(1, 2, 3)]
specs = generate_input_specs(inputs, lower_setting)
for spec, tensor in zip(specs, inputs):
--- py/torch_tensorrt/fx/test/quant/test_quant_trt.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/test/quant/test_quant_trt.py 2022-08-12 18:52:27.493933 +0000
@@ -46,13 +46,11 @@
shape_ranges=shape_ranges,
has_batch_dim=True,
)
]
- interp = TRTInterpreter(
- model, input_specs, explicit_batch_dimension=True, explicit_precision=True
- )
+ interp = TRTInterpreter(model, input_specs, explicit_batch_dimension=True, explicit_precision=True)
result = interp.run(lower_precision=LowerPrecision.INT8)
trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
return trt_mod
@@ -65,13 +63,11 @@
),
weight=torch.ao.quantization.default_weight_observer,
)
self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
- def _test_quantized_inputs_outputs(
- self, prepare_custom_config_dict, prepare_count_check, convert_count_check
- ):
+ def _test_quantized_inputs_outputs(self, prepare_custom_config_dict, prepare_count_check, convert_count_check):
"""
Test the option to have inputs and outputs of the graph quantized
"""
class M(torch.nn.Module):
@@ -113,13 +109,11 @@
# output of ref conv1 and output of ref conv2
ns.call_function(torch.quantize_per_tensor): 2,
# input of ref conv1 and input of ref conv2
ns.call_method("dequantize"): 2,
}
- self._test_quantized_inputs_outputs(
- prepare_custom_config_dict, prepare_count_check, convert_count_check
- )
+ self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)
def test_fp32_input_quantized_output(self):
prepare_custom_config_dict = {"output_quantized_idxs": [0]}
prepare_count_check = {
ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
@@ -128,13 +122,11 @@
# input, output of conv1 and output of conv2
ns.call_function(torch.quantize_per_tensor): 3,
# input of conv1, conv2
ns.call_method("dequantize"): 2,
}
- self._test_quantized_inputs_outputs(
- prepare_custom_config_dict, prepare_count_check, convert_count_check
- )
+ self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)
def test_quantized_input_fp32_output(self):
prepare_custom_config_dict = {"input_quantized_idxs": [0]}
prepare_count_check = {
ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
@@ -143,26 +135,22 @@
# output of conv1, conv2
ns.call_function(torch.quantize_per_tensor): 2,
# input of ref conv1, input of ref conv2, final output
ns.call_method("dequantize"): 3,
}
- self._test_quantized_inputs_outputs(
- prepare_custom_config_dict, prepare_count_check, convert_count_check
- )
+ self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)
def test_fp32_input_fp32_output(self):
prepare_custom_config_dict = {}
prepare_count_check = {
ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
}
convert_count_check = {
ns.call_function(torch.quantize_per_tensor): 3,
ns.call_method("dequantize"): 3,
}
- self._test_quantized_inputs_outputs(
- prepare_custom_config_dict, prepare_count_check, convert_count_check
- )
+ self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)
def _test_standalone_module(
self,
interface_config,
prepare_count_check,
@@ -213,20 +201,14 @@
data = torch.randn(1, 1, 1, 1)
# instantiate M and RefM and align the parameters
original_m = M().eval()
original_ref_m = RefM().eval()
- original_ref_m.conv1.weight = torch.nn.Parameter(
- original_m.conv.weight.detach()
- )
+ original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
- original_ref_m.conv2.weight = torch.nn.Parameter(
- original_m.standalone.conv.weight.detach()
- )
- original_ref_m.conv2.bias = torch.nn.Parameter(
- original_m.standalone.conv.bias.detach()
- )
+ original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
+ original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())
sm_example_inputs = (data,)
prepare_config = {
"standalone_module_name": [
(
@@ -253,20 +235,16 @@
backend_config=backend_config_dict,
)
# calibration
m(data)
self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
- self.checkGraphModuleNodes(
- m.standalone, expected_node_occurrence=standalone_prepare_count_check
- )
+ self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
# check converted/quantized model
m = convert_to_reference_fx(m, backend_config=backend_config_dict)
self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
- self.checkGraphModuleNodes(
- m.standalone, expected_node_occurrence=standalone_convert_count_check
- )
+ self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
res = m(data)
# quantize the reference model
ref_m = prepare_fx(
original_ref_m_copy,
@@ -285,17 +263,13 @@
"output_quantized_idxs": [], # float output
}
interface_config = float_interface_config
# input and output of first conv, observer for standalone module
# will be inserted in the standalone module itself
- prepare_count_check = {
- ns.call_module(torch.ao.quantization.HistogramObserver): 2
- }
+ prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
# for input and output of conv in the standalone module
- standalone_prepare_count_check = {
- ns.call_module(torch.ao.quantization.HistogramObserver): 2
- }
+ standalone_prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
convert_count_check = {
# input and output of reference conv
ns.call_function(torch.quantize_per_tensor): 2,
ns.call_module(nnqr.Conv2d): 1,
ns.call_method("dequantize"): 2,
@@ -351,17 +325,13 @@
"root_module": torch.nn.Conv2d,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
}
custom_backend_config_dict = {"configs": [conv_module_config]}
# observer for input and output of first conv
- prepare_count_check = {
- ns.call_module(torch.ao.quantization.HistogramObserver): 2
- }
+ prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
# for output of conv in the standalone module
- standalone_prepare_count_check = {
- ns.call_module(torch.ao.quantization.HistogramObserver): 1
- }
+ standalone_prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 1}
convert_count_check = {
# quantizing input/output for reference conv
ns.call_function(torch.quantize_per_tensor): 2,
ns.call_module(nnqr.Conv2d): 1,
# dequantize the input of reference conv and
@@ -400,13 +370,11 @@
),
weight=torch.ao.quantization.default_weight_observer,
)
self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
- def _test_module(
- self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False
- ):
+ def _test_module(self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False):
"""
Args:
m: the float module we want to test
inputs: list of inputs for the module
shape_ranges: a list of shape_range, where every shape_range is a tuple of
@@ -468,13 +436,11 @@
def forward(self, x):
return self.relu(self.conv(x))
# just testing conv2d since conv1d and conv3d are not supported in fx2trt
- for dim, has_relu, f_relu, is_qat in itertools.product(
- [1, 2], [True, False], [True, False], [True, False]
- ):
+ for dim, has_relu, f_relu, is_qat in itertools.product([1, 2], [True, False], [True, False], [True, False]):
# when has_relu=False, we have torch.nn.Identity, which would introduce
            # extra quant-dequant pair
no_convert = {
ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
ns.call_method("dequantize"): 2 + int(not has_relu),
@@ -510,13 +476,11 @@
return self.relu(self.linear(x))
linear_input = torch.rand(8, 5)
shape_ranges = [((1, 5), (5, 5), (10, 5))]
- for has_relu, f_relu, is_qat in itertools.product(
- [True, False], [True, False], [True, False]
- ):
+ for has_relu, f_relu, is_qat in itertools.product([True, False], [True, False], [True, False]):
# when has_relu=False, we have torch.nn.Identity, which would introduce
            # extra quant-dequant pair
no_convert = {
ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
ns.call_method("dequantize"): 2 + int(not has_relu),
@@ -662,13 +626,11 @@
ns.call_function(torch.addmm): 1,
ns.call_method("dequantize"): 3,
}
self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
- @unittest.skip(
- "This is not supported yet, we can enable the test after it's supported"
- )
+ @unittest.skip("This is not supported yet, we can enable the test after it's supported")
def test_conv_add(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
@@ -828,13 +790,11 @@
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
standalone_node_occurrence = {
# output of the standalone module
ns.call_module(torch.ao.quantization.HistogramObserver): 1,
}
- self.checkGraphModuleNodes(
- m.standalone, expected_node_occurrence=standalone_node_occurrence
- )
+ self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
m = convert_to_reference_fx(m, backend_config=backend_config_dict)
node_occurrence = {
# two inputs for standalone module
ns.call_function(torch.quantize_per_tensor): 2,
ns.call_module(nn.Conv2d): 1,
@@ -847,13 +807,11 @@
ns.call_module(nn.Conv2d): 1,
ns.call_module(torch.nn.ReLU): 1,
# two input and one output for the pattern in standalone module
ns.call_method("dequantize"): 3,
}
- self.checkGraphModuleNodes(
- m.standalone, expected_node_occurrence=standalone_node_occurrence
- )
+ self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
def test_quant_dequant_not_fold(self):
class LinearModule(torch.nn.Module):
def __init__(self):
super().__init__()
--- py/torch_tensorrt/fx/tools/common_fx2trt.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/tools/common_fx2trt.py 2022-08-12 18:52:27.791019 +0000
@@ -29,13 +29,11 @@
"""
target_atoms = target.split(".")
attr_itr = mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
- raise RuntimeError(
- f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
- )
+ raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available")
@@ -82,13 +80,11 @@
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
outputs = trt_mod(*cuda_inputs)
end_event.record()
torch.cuda.synchronize()
- _LOGGER.info(
- f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}"
- )
+ _LOGGER.info(f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}")
if isinstance(outputs, torch.Tensor):
ref_outputs = [ref_outputs]
outputs = [outputs]
for out, ref in zip(outputs, ref_outputs):
@@ -126,26 +122,22 @@
mod.eval()
if len(expected_ops):
self.assert_has_op(mod, expected_ops)
interpreter_result = interpreter.run(
- lower_precision=LowerPrecision.FP16
- if fp16_mode
- else LowerPrecision.FP32
+ lower_precision=LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
)
trt_mod = TRTModule(
interpreter_result.engine,
interpreter_result.input_names,
interpreter_result.output_names,
)
res_trt = trt_mod(*cuda_inputs).cpu()
res_cpu = mod(*inputs)
assert len(res_trt) == len(res_cpu)
assert len(res_cpu) == len(comparators)
- for output_trt, output_cpu, comparator in zip(
- res_trt, res_cpu, comparators
- ):
+ for output_trt, output_cpu, comparator in zip(res_trt, res_cpu, comparators):
comp_func = comparator[0]
args = comparator[1]
self.assertTrue(comp_func(output_trt, output_cpu, *args))
def run_test_with_error(self, mod, inputs, interpreter, expect_error):
@@ -165,13 +157,11 @@
if node.op == "call_module":
ops_in_mod.add(type(fetch_attr(mod, node.target)))
elif node.op in {"call_function", "call_method"}:
ops_in_mod.add(node.target)
- self.assertTrue(
- ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}"
- )
+ self.assertTrue(ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}")
def assert_unexpected_op(self, mod, ops):
for node in mod.graph.nodes:
if node.op == "call_module":
if type(fetch_attr(mod, node.target)) in ops:
@@ -204,13 +194,11 @@
# after we refactor the internal callsites to use this file
mod = torch.fx.symbolic_trace(mod)
shape_prop.ShapeProp(mod).propagate(*inputs)
mod = NormalizeArgs(mod).transform()
interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
- super().run_test_custom_compare_results(
- mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
- )
+ super().run_test_custom_compare_results(mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode)
class AccTestCase(TRTTestCase):
def run_test(
self,
@@ -233,41 +221,31 @@
pass_tracer = chain_passes(*apply_passes)
mod = pass_tracer(mod, inputs)
if test_implicit_batch_dim:
interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
- super().run_test(
- mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
- )
+ super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)
if test_explicit_batch_dim:
- interp = TRTInterpreter(
- mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
- )
- super().run_test(
- mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
- )
+ interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
+ super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)
if test_explicit_precision:
interp = TRTInterpreter(
mod,
InputTensorSpec.from_tensors(inputs),
explicit_precision=test_explicit_precision,
)
- super().run_test(
- mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol
- )
+ super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)
interp = TRTInterpreter(
mod,
InputTensorSpec.from_tensors(inputs),
explicit_batch_dimension=True,
explicit_precision=test_explicit_precision,
)
- super().run_test(
- mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
- )
+ super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)
def run_test_with_assert_error(
self,
mod,
inputs,
@@ -281,13 +259,11 @@
if test_implicit_batch_dim:
interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
super().run_test_with_error(mod, inputs, interp, expect_error)
if test_explicit_batch_dim:
- interp = TRTInterpreter(
- mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
- )
+ interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
super().run_test_with_error(mod, inputs, interp, expect_error)
def run_test_with_dynamic_shape(
self,
mod,
--- py/torch_tensorrt/fx/tools/trt_splitter.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/tools/trt_splitter.py 2022-08-12 18:52:27.864356 +0000
@@ -72,13 +72,11 @@
operator_support,
settings,
non_acc_submodule_name="_run_on_gpu_",
)
- def _lower_model_to_backend(
- self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]
- ):
+ def _lower_model_to_backend(self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]):
"""
Lower a GraphModule `mod` to TensorRT with `inputs`.
"""
# Current code for lowering is place-holder, subject to future change
# based on feeds model's actual status
--- py/torch_tensorrt/fx/tools/trt_minimizer.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/tools/trt_minimizer.py 2022-08-12 18:52:27.879445 +0000
@@ -8,16 +8,12 @@
from .. import InputTensorSpec, TRTInterpreter, TRTModule
_LOGGER: logging.Logger = logging.getLogger(__name__)
-def lower_mod_default(
- mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048
-) -> TRTModule:
- interp = TRTInterpreter(
- mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
- )
+def lower_mod_default(mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048) -> TRTModule:
+ interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
interpreter_result = interp.run(max_batch_size=batch_size)
res_mod = TRTModule(
interpreter_result.engine,
interpreter_result.input_names,
interpreter_result.output_names,
@@ -37,13 +33,11 @@
module: torch.fx.GraphModule,
sample_input: Tensors,
compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]],
settings: TensorRTMinizerSetting = TensorRTMinizerSetting(),
max_batch_size: Any = 2048,
- lower_fn: Callable[
- [torch.fx.GraphModule, Tensors, Any], TRTModule
- ] = lower_mod_default,
+ lower_fn: Callable[[torch.fx.GraphModule, Tensors, Any], TRTModule] = lower_mod_default,
):
self.lower_fn = lower_fn
self.max_batch_size = max_batch_size
super().__init__(module, sample_input, compare_fn, settings)
@@ -56,13 +50,11 @@
mod.eval()
try:
mod = self.lower_fn(mod, inputs, self.max_batch_size)
output = mod(*inputs)
except RuntimeError as e:
- raise net_min_base.FxNetMinimizerRunFuncError(
- f"Encounter an error when processing \n{mod.graph}\n {e}"
- )
+ raise net_min_base.FxNetMinimizerRunFuncError(f"Encounter an error when processing \n{mod.graph}\n {e}")
else:
return output
def get_nodes(self, start=None, end=None, enable_print=False):
nodes = self._collect_nodes(start, end)
--- py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py 2022-08-12 18:48:48.794461 +0000
+++ py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py 2022-08-12 18:52:28.165165 +0000
@@ -41,13 +41,11 @@
def __init__(self):
super().__init__()
self.exceptions_rewritten: Set[Type[Exception]] = set()
self.exceptions_bool_rewritten: Set[Type[Exception]] = set()
- def rewrite(
- self, fn: FunctionType
- ) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:
+ def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:
# Normalize the source lines
sourcelines, _ = inspect.getsourcelines(fn)
sourcelines = normalize_source_lines(sourcelines)
source = "".join(sourcelines)
@@ -139,12 +137,11 @@
return if_node
# Check that we actually have a builtin exception.
if (
not issubclass(exc_type, Exception)
- or getattr(getattr(exc_type, "__class__", None), "__module__", None)
- != "builtins"
+ or getattr(getattr(exc_type, "__class__", None), "__module__", None) != "builtins"
):
return if_node
# We need a ConditionalExceptionWrapper specialized for every kind of
# exception, so add it to exceptions_rewritten to remember for later to
@@ -156,23 +153,17 @@
# the If with, with args set as the If's condition and the string of the
# exception. The call to the self._conditional_exception_wrapper_*Error
# module is safe because the RewrittenModule will add it as an attr
# based on the returned exceptions_rewritten, and we assume we are
# currently modifying the AST of a method from a RewrittenModule.
- exc_wrapper_node = ast.parse(
- f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval"
- )
+ exc_wrapper_node = ast.parse(f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval")
assert isinstance(exc_wrapper_node, ast.Expression)
exc_wrapper_call_node = exc_wrapper_node.body
assert isinstance(exc_wrapper_call_node, ast.Call)
- if isinstance(if_node.test, ast.BoolOp) and isinstance(
- if_node.test.op, ast.And
- ):
+ if isinstance(if_node.test, ast.BoolOp) and isinstance(if_node.test.op, ast.And):
self.exceptions_bool_rewritten.add(exc_type)
- bool_wrapper_node = ast.parse(
- f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval"
- )
+ bool_wrapper_node = ast.parse(f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval")
assert isinstance(exc_wrapper_node, ast.Expression)
bool_wrapper_call_node = bool_wrapper_node.body
assert isinstance(exc_wrapper_call_node, ast.Call)
bool_wrapper_call_node.args = if_node.test.values
exc_wrapper_call_node.args = [
@@ -323,13 +314,11 @@
name_target[-1] == "_"
and name_target[0] != "_"
and not (name_target in allow_list)
and kind != "placeholder"
):
- raise RuntimeError(
- f"Tried to trace mutable operation {name_target}. FX only supports functional code"
- )
+ raise RuntimeError(f"Tried to trace mutable operation {name_target}. FX only supports functional code")
return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
# List of modules that need rewriting to be supported for tracing.
@@ -384,13 +373,11 @@
# Write all of the non-dunder or special methods from base_class
# into RewrittenModule.
for method_name in dir(base_class):
method = getattr(base_class, method_name, None)
if method is None and method_name not in {"__doc__"}:
- _LOGGER.warning(
- f"{__qualname__} does not have attribute {method_name}"
- )
+ _LOGGER.warning(f"{__qualname__} does not have attribute {method_name}")
if builtins.type(method) is not FunctionType:
continue
# Always skip rewriting dunder methods, as they haven't (yet) been
@@ -437,13 +424,11 @@
# Recursively rewrite and copy all module attrs of this module.
for k, v in orig.__dict__.items():
if k == "_modules":
for mod_k, mod_v in v.items():
if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list: # type: ignore[operator]
- _LOGGER.info(
- f"Skip rewriting leaf module {type(mod_v)}"
- )
+ _LOGGER.info(f"Skip rewriting leaf module {type(mod_v)}")
self._modules[mod_k] = mod_v
else:
self._modules[mod_k] = rewrite_module(mod_v)
else:
self.__dict__[k] = v
@@ -475,25 +460,21 @@
"""
changed = False
for node in reversed(gm.graph.nodes):
if node.op == "call_module" and (
isinstance(gm.get_submodule(node.target), ConditionalExceptionWrapper)
- or isinstance(
- gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper
- )
+ or isinstance(gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper)
):
gm.graph.erase_node(node)
changed = True
return changed
def _replace_tensor_meta_with_rank(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
if node.op != "output" and "tensor_meta" in node.meta:
- node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(
- node.meta["tensor_meta"], lambda x: len(x.shape)
- )
+ node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(node.meta["tensor_meta"], lambda x: len(x.shape))
del node.meta["tensor_meta"]
def rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list):
rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
@frank-wei can you share the fb lint config or something so that we can use a consistent code style?
They are using black, but it looks like it is more than that.
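For anyone trying to reproduce that style locally, here is a minimal sketch. It assumes the internal formatter is essentially black run with a longer line length; the 120-character limit and the file path below are illustrative guesses inferred from the reformatted lines in the bot output, not confirmed settings of the fb config.
import difflib
import pathlib

import black  # pip install black

# Format one of the flagged files the way the bot appears to, then diff the result.
path = pathlib.Path("py/torch_tensorrt/fx/lower.py")
src = path.read_text()
# line_length=120 is an assumption, not a confirmed internal setting.
formatted = black.format_str(src, mode=black.Mode(line_length=120))

diff = difflib.unified_diff(
    src.splitlines(keepends=True),
    formatted.splitlines(keepends=True),
    fromfile=str(path),
    tofile=str(path),
)
print("".join(diff))
Running the same script with the repo's default 88-character limit instead reproduces the multi-line style on the `-` side of the hunks above, which is consistent with the two configs disagreeing mainly on line length.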
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- py/torch_tensorrt/fx/input_tensor_spec.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/input_tensor_spec.py 2022-08-12 19:19:58.709915 +0000
@@ -6,14 +6,11 @@
from .utils import get_dynamic_dims
def generate_input_specs(inputs, lower_setting, additional_inputs=None):
# dynamic_batch is TRT only flag.
- if (
- not lower_setting.explicit_batch_dimension
- or lower_setting.dynamic_batch is False
- ):
+ if not lower_setting.explicit_batch_dimension or lower_setting.dynamic_batch is False:
return InputTensorSpec.from_tensors(inputs)
# If we don't have additional inputs, we assume the first dimension
# is the dynamic batch dimension. Otherwise, we use the additional
# inputs to determine the batch dimension.
@@ -33,20 +30,16 @@
for i, j in zip(inputs, additional_inputs):
found_batch_dim = False
for idx, values in enumerate(zip(i.shape, j.shape)):
if values[0] != values[1]:
- assert (
- found_batch_dim is False
- ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+ assert found_batch_dim is False, f"We've already found a batch dim, {i.shape}, {j.shape}."
batch_dims.append(idx)
found_batch_dim = True
if not found_batch_dim:
- raise RuntimeError(
- f"Failed to find batch dimension because shapes are the same, {i.shape}"
- )
+ raise RuntimeError(f"Failed to find batch dimension because shapes are the same, {i.shape}")
return InputTensorSpec.from_tensors_with_dynamic_batch_size(
inputs,
(
0,
@@ -158,13 +151,11 @@
batch_dim
), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
shape = list(tensor.shape)
shape[batch_dim] = -1
shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item]
- input_specs.append(
- cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
- )
+ input_specs.append(cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges))
return input_specs
def to_random_tensor(self):
shape = tuple(self.shape)
--- py/torch_tensorrt/fx/lower.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/lower.py 2022-08-12 19:19:58.801249 +0000
@@ -77,13 +77,11 @@
lower_setting: LowerSetting
timing_cache_manager: TimingCacheManager
@classmethod
def create(cls, lower_setting):
- timing_cache_manager = TimingCacheManager(
- lower_setting.timing_cache_prefix, lower_setting.save_timing_cache
- )
+ timing_cache_manager = TimingCacheManager(lower_setting.timing_cache_prefix, lower_setting.save_timing_cache)
return LowerTrtInterpreter(lower_setting, timing_cache_manager)
def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
logger.info(f"{split_name=} {self.lower_setting.input_specs=}")
@@ -103,13 +101,11 @@
interpreter = TRTInterpreter(
mod,
input_specs=self.lower_setting.input_specs,
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
explicit_precision=self.lower_setting.explicit_precision,
- logger_level=trt.Logger.VERBOSE
- if self.lower_setting.verbose_log
- else trt.Logger.WARNING,
+ logger_level=trt.Logger.VERBOSE if self.lower_setting.verbose_log else trt.Logger.WARNING,
)
interp_result: TRTInterpreterResult = interpreter.run(
max_batch_size=self.lower_setting.max_batch_size,
max_workspace_size=self.lower_setting.max_workspace_size,
@@ -129,13 +125,11 @@
self.timing_cache_manager.update_timing_cache(split_name, timing_cache)
return interp_result
-def default_split_function(
- model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting
-) -> SplitResult:
+def default_split_function(model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting) -> SplitResult:
splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
splitter.node_support_preview()
@@ -147,13 +141,11 @@
def default_lower_pass(
create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter],
) -> PassFunc:
- def lower_pass(
- mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str
- ) -> nn.Module:
+ def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str) -> nn.Module:
"""
Create a module transformation pass which lowers an `fx.GraphModule` into a
`TRTModule`
"""
interpreter = create_trt_interpreter(lower_setting)
@@ -223,21 +215,13 @@
inputs: Input,
additional_inputs: Optional[Input] = None,
) -> nn.Module:
module.eval()
- if (
- self.lower_pass_manager_builder.lower_setting.lower_precision
- == LowerPrecision.FP16
- ):
+ if self.lower_pass_manager_builder.lower_setting.lower_precision == LowerPrecision.FP16:
module.half()
- inputs = tuple(
- x.half() if x is not None and x.dtype == torch.float32 else x
- for x in inputs
- )
- pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
- inputs, additional_inputs
- )
+ inputs = tuple(x.half() if x is not None and x.dtype == torch.float32 else x for x in inputs)
+ pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(inputs, additional_inputs)
lower_result = pm(module)
return lower_result
--- py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2022-08-12 19:19:59.035315 +0000
@@ -35,23 +35,17 @@
# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input):
# >>> # print_module_and_input will be called right after the fuse passes
# >>> lower(module, sample_input)
# Observer for the model after the fuse passes.
-FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer(
- "FUSE_PASSES_POST_OBSERVER"
-)
+FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer("FUSE_PASSES_POST_OBSERVER")
# Observer for the TRT split submodules before lowering
-LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer(
- "LOWER_SPLIT_PRE_OBSERVER"
-)
+LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer("LOWER_SPLIT_PRE_OBSERVER")
# Observer for the TRT split submodules after lowering
-LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer(
- "LOWER_SPLIT_POST_OBSERVER"
-)
+LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer("LOWER_SPLIT_POST_OBSERVER")
# ----------------------------------------------------------------------
def wrapper(fn: Callable, input) -> Callable:
@wraps(fn)
@@ -103,22 +97,16 @@
passes.append(wrapper(p, self._input))
for p in self.lower_setting.lower_basic_fuse_pass.passes:
passes.append(wrapper(p, self._input))
passes.append(inplace_wrapper(common_subexpression_elimination))
- passes.append(
- inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
- )
+ passes.append(inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)))
return PassManager.build_from_passlist(passes)
def _split_pass(self) -> PassManager:
- passes = [
- partial(
- self._split_func, inputs=self._input, lower_setting=self.lower_setting
- )
- ]
+ passes = [partial(self._split_func, inputs=self._input, lower_setting=self.lower_setting)]
passes.append(
inplace_wrapper(
lambda split_result: remove_duplicate_output_args(
split_result.split_module, split_result.submodule_inputs.keys()
)
@@ -152,21 +140,15 @@
lowering_start_time = datetime.datetime.now()
self.lower_setting.input_specs = generate_input_specs(
submod_inputs,
self.lower_setting,
- additional_submodule_inputs[submod_name]
- if additional_submodule_inputs
- else None,
+ additional_submodule_inputs[submod_name] if additional_submodule_inputs else None,
)
- lowered_module = self._lower_func(
- submod, submod_inputs, self.lower_setting, submod_name
- )
+ lowered_module = self._lower_func(submod, submod_inputs, self.lower_setting, submod_name)
setattr(split_result.split_module, submod_name, lowered_module)
- LOWER_SPLIT_POST_OBSERVER.observe(
- submod_name, lowered_module, submod_inputs
- )
+ LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
)
return split_result.split_module
@@ -184,28 +166,22 @@
# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
_LOGGER.info(f"Now lowering submodule {submod_name}")
lowering_start_time = datetime.datetime.now()
- lowered_module = self._lower_func(
- submod, submod_inputs, self.lower_setting, submod_name
- )
+ lowered_module = self._lower_func(submod, submod_inputs, self.lower_setting, submod_name)
setattr(split_result.split_module, submod_name, lowered_module)
- LOWER_SPLIT_POST_OBSERVER.observe(
- submod_name, lowered_module, submod_inputs
- )
+ LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
)
return split_result.split_module
return PassManager.build_from_passlist([lower_func])
- def build_trt_lower_pipeline(
- self, input: Input, additional_input: Optional[Input] = None
- ) -> PassManager:
+ def build_trt_lower_pipeline(self, input: Input, additional_input: Optional[Input] = None) -> PassManager:
self._input = input
self._additional_input = additional_input
passes = []
passes.append(self._const_fold_pass())
@@ -214,13 +190,11 @@
passes.append(self._trt_lower_pass())
pm = PassManager.build_from_passlist(passes)
return pm
- def build_default_lower_pipeline(
- self, input: Input, additional_input: Optional[Input] = None
- ) -> PassManager:
+ def build_default_lower_pipeline(self, input: Input, additional_input: Optional[Input] = None) -> PassManager:
self._input = input
self._additional_input = additional_input
passes = []
passes.append(self._const_fold_pass())
--- py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py 2022-08-12 19:19:59.232592 +0000
@@ -27,13 +27,11 @@
count_include_pad=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.avg_pool = torch.nn.AvgPool1d(
- kernel_size, stride, padding, ceil_mode, count_include_pad
- )
+ self.avg_pool = torch.nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)
def forward(self, x):
return self.avg_pool(x)
inputs = [torch.randn(1, 3, 224)]
@@ -60,13 +58,11 @@
count_include_pad=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.avg_pool = torch.nn.AvgPool1d(
- kernel_size, stride, padding, ceil_mode, count_include_pad
- )
+ self.avg_pool = torch.nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)
def forward(self, x):
return self.avg_pool(x)
input_specs = [
@@ -75,13 +71,11 @@
dtype=torch.float32,
shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d})
def test_avg_pool2d_with_dynamic_shape_four_dimensions(
self,
test_name="default",
kernel_size=1,
@@ -112,13 +106,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d})
@parameterized.expand(
[
("default", 1),
("kernal_size", 3),
@@ -254,12 +246,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py 2022-08-12 19:19:59.271217 +0000
@@ -32,13 +32,11 @@
dtype=torch.float32,
shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.batch_norm})
def test_batchnorm_with_dynamic_shape(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -53,13 +51,11 @@
dtype=torch.float32,
shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.batch_norm})
# Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm.
if __name__ == "__main__":
--- py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py 2022-08-12 19:19:59.400355 +0000
@@ -51,12 +51,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.clamp}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.clamp})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py 2022-08-12 19:19:59.500742 +0000
@@ -27,13 +27,11 @@
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv1d(
- 3, 6, kernel_size, stride, padding, dilation, groups, bias
- )
+ self.conv = torch.nn.Conv1d(3, 6, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
return self.conv(x)
inputs = [torch.randn(1, 3, 32)]
@@ -60,13 +58,11 @@
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv1d(
- 3, 6, kernel_size, stride, padding, dilation, groups, bias
- )
+ self.conv = torch.nn.Conv1d(3, 6, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
return self.conv(x)
input_specs = [
@@ -75,13 +71,11 @@
dtype=torch.float32,
shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.conv1d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv1d})
@parameterized.expand(
[
("default", 1),
param("no_bias", 1, bias=False),
@@ -102,13 +96,11 @@
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv2d(
- 3, 6, kernel_size, stride, padding, dilation, groups, bias
- )
+ self.conv = torch.nn.Conv2d(3, 6, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
return self.conv(x)
inputs = [torch.randn(1, 3, 32, 32)]
@@ -131,13 +123,11 @@
shape=(-1, 3, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.conv2d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv2d})
@parameterized.expand(
[
("default", 1),
param("no_bias", 1, bias=False),
@@ -158,13 +148,11 @@
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv3d(
- 3, 6, kernel_size, stride, padding, dilation, groups, bias
- )
+ self.conv = torch.nn.Conv3d(3, 6, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
return self.conv(x)
inputs = [torch.randn(1, 3, 32, 32, 32)]
@@ -187,12 +175,10 @@
shape=(-1, 3, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.conv3d}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.conv3d})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py 2022-08-12 19:19:59.662120 +0000
@@ -5,13 +5,11 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
-@unittest.skip(
- reason="Could not find CustomGeluPluginDynamic. Enable it once we upgrade TRT to 8.4"
-)
+@unittest.skip(reason="Could not find CustomGeluPluginDynamic. Enable it once we upgrade TRT to 8.4")
class TestGELU(AccTestCase):
def test_gelu(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.gelu(x)
@@ -34,13 +32,11 @@
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.gelu}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.gelu})
def test_gelu_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.gelu(x)
@@ -51,12 +47,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.gelu}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.gelu})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py 2022-08-12 19:19:59.880992 +0000
@@ -131,12 +131,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Interpolate(), input_specs, expected_ops={acc_ops.interpolate}
- )
+ self.run_test_with_dynamic_shape(Interpolate(), input_specs, expected_ops={acc_ops.interpolate})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py 2022-08-12 19:20:00.162992 +0000
@@ -71,15 +71,11 @@
class MatMul(nn.Module):
def forward(self, input, other):
return torch.matmul(input, other)
inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
- test_implicit_batch_dim = (
- input_shape[0] == other_shape[0]
- and len(input_shape) > 2
- and len(other_shape) > 2
- )
+ test_implicit_batch_dim = input_shape[0] == other_shape[0] and len(input_shape) > 2 and len(other_shape) > 2
self.run_test(
MatMul(),
inputs,
expected_ops={acc_ops.matmul},
test_implicit_batch_dim=test_implicit_batch_dim,
@@ -106,12 +102,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 3, 3), (9, 4, 3, 3), (9, 4, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Matmul(), input_specs, expected_ops={acc_ops.matmul}
- )
+ self.run_test_with_dynamic_shape(Matmul(), input_specs, expected_ops={acc_ops.matmul})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_max.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_max.py 2022-08-12 19:20:00.274826 +0000
@@ -102,13 +102,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce}
- )
+ self.run_test_with_dynamic_shape(MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce})
def test_max_full_reduce(
self,
):
class MaxFullReduce(torch.nn.Module):
@@ -124,13 +122,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce}
- )
+ self.run_test_with_dynamic_shape(MaxFullReduce(), input_specs, expected_ops={acc_ops.max_full_reduce})
def test_max_method(self):
class MaxMethod(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -149,12 +145,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MaxMethod(), input_specs, expected_ops={acc_ops.maximum}
- )
+ self.run_test_with_dynamic_shape(MaxMethod(), input_specs, expected_ops={acc_ops.maximum})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_min.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_min.py 2022-08-12 19:20:00.440385 +0000
@@ -101,13 +101,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce}
- )
+ self.run_test_with_dynamic_shape(MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce})
def test_min_full_reduce(
self,
):
class MinFullReduce(torch.nn.Module):
@@ -123,13 +121,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce}
- )
+ self.run_test_with_dynamic_shape(MinFullReduce(), input_specs, expected_ops={acc_ops.min_full_reduce})
def test_min_method(self):
class MinMethod(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -148,12 +144,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- MinMethod(), input_specs, expected_ops={acc_ops.minimum}
- )
+ self.run_test_with_dynamic_shape(MinMethod(), input_specs, expected_ops={acc_ops.minimum})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py 2022-08-12 19:20:00.490650 +0000
@@ -23,13 +23,11 @@
dtype=torch.float32,
shape_ranges=[((1, 3, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))],
),
]
- self.run_test_with_dynamic_shape(
- Narrow(), input_specs, expected_ops={acc_ops.slice_tensor}
- )
+ self.run_test_with_dynamic_shape(Narrow(), input_specs, expected_ops={acc_ops.slice_tensor})
class TestNarrowConverter(AccTestCase):
@parameterized.expand(
[
--- py/torch_tensorrt/fx/converters/acc_ops_converters.py 2022-08-12 19:16:11.708868 +0000
+++ py/torch_tensorrt/fx/converters/acc_ops_converters.py 2022-08-12 19:20:00.578833 +0000
@@ -34,14 +34,11 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Conv received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Conv received input {input_val} that is not part " "of the TensorRT region!")
# Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d
unsqueeze_layer = network.add_shuffle(input=input_val)
unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
@@ -52,13 +49,11 @@
# for now we'll assume bias is constant Tensor or None,
# and bias being ITensor is not supported in TensorRT api
# right now
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
- raise RuntimeError(
- f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]"
- )
+ raise RuntimeError(f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]")
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
if bias is not None:
bias = bias[None]
weight = kwargs["weight"]
@@ -82,13 +77,11 @@
)
layer.set_input(1, weight)
else:
if not isinstance(kwargs["weight"], torch.Tensor):
- raise RuntimeError(
- f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]"
- )
+ raise RuntimeError(f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]")
weight = to_numpy(weight)
weight = np.expand_dims(weight, -1)
layer = network.add_convolution_nd(
input=input_val,
num_output_maps=weight.shape[0],
@@ -126,25 +119,20 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Conv received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Conv received input {input_val} that is not part " "of the TensorRT region!")
if has_dynamic_shape(input_val.shape):
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
# for now we'll assume bias is constant Tensor or None,
# and bias being ITensor is not supported in TensorRT api
# right now
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
- raise RuntimeError(
- f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]"
- )
+ raise RuntimeError(f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tenosr]")
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
if network.has_explicit_precision:
weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
@@ -160,13 +148,11 @@
)
layer.set_input(1, weight)
else:
if not isinstance(kwargs["weight"], torch.Tensor):
- raise RuntimeError(
- f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]"
- )
+ raise RuntimeError(f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tenosr]")
weight = to_numpy(kwargs["weight"])
layer = network.add_convolution_nd(
input=input_val,
num_output_maps=weight.shape[0],
kernel_shape=weight.shape[2:],
@@ -194,27 +180,20 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Transpose conv received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Transpose conv received input {input_val} that is not part " "of the TensorRT region!")
if has_dynamic_shape(input_val.shape):
- assert (
- input_val.shape[1] != -1
- ), "Channel dim can't be dynamic for transpose convolution."
+ assert input_val.shape[1] != -1, "Channel dim can't be dynamic for transpose convolution."
# for now we'll assume bias is constant Tensor or None,
# and bias being ITensor is not supported in TensorRT api
# right now
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
- raise RuntimeError(
- f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
- )
+ raise RuntimeError(f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]")
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
if network.has_explicit_precision:
weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
@@ -232,13 +211,11 @@
)
layer.set_input(1, weight)
else:
if not isinstance(kwargs["weight"], torch.Tensor):
- raise RuntimeError(
- f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
- )
+ raise RuntimeError(f"conv {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]")
weight = to_numpy(kwargs["weight"])
# nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
layer = network.add_deconvolution_nd(
input=input_val,
num_output_maps=weight.shape[1] * kwargs["groups"],
@@ -270,29 +247,20 @@
mode = kwargs["mode"]
value = kwargs["value"] if kwargs["value"] is not None else 0
rank = len(input_val.shape) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"pad received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"pad received input {input_val} that is not part " "of the TensorRT region!")
if mode != "constant":
- raise RuntimeError(
- f"Currently we only support constant mode for pad, got {mode}."
- )
+ raise RuntimeError(f"Currently we only support constant mode for pad, got {mode}.")
if len(pad) / 2 > rank:
- raise RuntimeError(
- f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
- )
+ raise RuntimeError(f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension.")
if value != 0:
- raise RuntimeError(
- f"Currently we only support padding value of 0, got {value}."
- )
+ raise RuntimeError(f"Currently we only support padding value of 0, got {value}.")
if len(pad) > 4:
raise RuntimeError("Currently we only support padding last two dimensions.")
pre_padding = tuple(pad[len(pad) - i - 2] for i in range(0, len(pad), 2))
@@ -320,38 +288,28 @@
mode = kwargs["mode"]
value = kwargs["value"] if kwargs["value"] is not None else 0
rank = len(input_val.shape) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"pad received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"pad received input {input_val} that is not part " "of the TensorRT region!")
if mode != "constant":
- raise RuntimeError(
- f"Currently we only support constant mode for pad, got {mode}."
- )
+ raise RuntimeError(f"Currently we only support constant mode for pad, got {mode}.")
if len(pad) / 2 > rank:
- raise RuntimeError(
- f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
- )
+ raise RuntimeError(f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension.")
# cast value to TRTensor
dt = torch_dtype_from_trt(input_val.dtype)
value = 0 if value == None else value
- value_const = get_trt_tensor(
- network, torch.tensor([value], dtype=dt), f"{name}_value"
- )
+ value_const = get_trt_tensor(network, torch.tensor([value], dtype=dt), f"{name}_value")
input_shape = input_val.shape
pre_start = tuple(i - 1 for i in input_shape)
prefix_len = len(input_shape) - len(pad) // 2
pre_shape = tuple(
- input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0)
- for i in range(0, len(input_shape))
+ input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0) for i in range(0, len(input_shape))
)
pre_stride = [-1] * len(input_shape)
layer = network.add_slice(
input_val,
@@ -374,12 +332,11 @@
transpose_output = layer.get_output(0)
shape = transpose_output.shape
post_start = tuple([0] * len(shape))
post_shape = tuple(
- shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0)
- for i in range(0, len(shape))
+ shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0) for i in range(0, len(shape))
)
post_stride = tuple([1] * len(shape))
layer = network.add_slice(transpose_output, post_start, post_shape, post_stride)
layer.set_input(4, value_const)
@@ -397,22 +354,15 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"flatten received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"flatten received input {input_val} that is not part " "of the TensorRT region!")
num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
- start_dim = get_positive_dim(
- cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims
- )
- end_dim = get_positive_dim(
- cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims
- )
+ start_dim = get_positive_dim(cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims)
+ end_dim = get_positive_dim(cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims)
if network.has_implicit_batch_dimension:
assert start_dim != 0, "Can't flatten batch dimension when it's implicit."
start_dim -= 1
end_dim -= 1
@@ -511,24 +461,18 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
if type(input_t) == torch.nn.Parameter or type(input_t) == torch.Tensor:
- if (
- not has_dynamic_shape(input_t.shape)
- and network.has_implicit_batch_dimension
- ):
+ if not has_dynamic_shape(input_t.shape) and network.has_implicit_batch_dimension:
return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_t.shape))
return input_t.shape
# input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
input_val = input_t
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"size received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"size received input {input_val} that is not part " "of the TensorRT region!")
if not has_dynamic_shape(input_val.shape):
if network.has_implicit_batch_dimension:
return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_val.shape))
return torch.Size(input_val.shape)
@@ -547,14 +491,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"size received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"size received input {input_val} that is not part " "of the TensorRT region!")
if has_dynamic_shape(input_val.shape):
raise RuntimeError(f"numel does not support dynamic shapes.")
numel = np.prod(input_val.shape)
@@ -572,29 +513,20 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"BatchNorm2d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"BatchNorm2d received input {input_val} that is not part " "of the TensorRT region!")
if has_dynamic_shape(input_val.shape):
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
- scale = cast(
- torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))
- ) / np.sqrt(
- cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"])))
- + cast(float, kwargs["eps"])
+ scale = cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["weight"]))) / np.sqrt(
+ cast(torch.Tensor, to_numpy(cast(torch.Tensor, kwargs["running_var"]))) + cast(float, kwargs["eps"])
)
- bias = (
- to_numpy(cast(torch.Tensor, kwargs["bias"]))
- - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
- )
+ bias = to_numpy(cast(torch.Tensor, kwargs["bias"])) - to_numpy(cast(torch.Tensor, kwargs["running_mean"])) * scale
power = np.ones_like(scale)
# For BatchNorm1d, reshape 1d to 2d
output_shape = input_val.shape
if not network.has_implicit_batch_dimension and len(input_val.shape) < 4:
@@ -628,44 +560,33 @@
@tensorrt_converter(acc_ops.layer_norm)
def acc_ops_layer_norm(network, target, args, kwargs, name):
input_val = kwargs["input"]
if not isinstance(input_val, trt.tensorrt.ITensor):
- raise RuntimeError(
- f"LayerNorm received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"LayerNorm received input {input_val} that is not part " "of the TensorRT region!")
gamma = kwargs["weight"].detach().cpu().float().numpy()
gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
beta = kwargs["bias"].detach().cpu().float().numpy()
beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
- eps_field = trt.PluginField(
- "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
- )
+ eps_field = trt.PluginField("eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32)
try:
normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
except TypeError:
_LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
normalized_shape = np.array([], dtype=np.int32)
- normalized_shape_filed = trt.PluginField(
- "normalized_shape", normalized_shape, trt.PluginFieldType.INT32
- )
- field_collection = trt.PluginFieldCollection(
- [gamma_field, beta_field, eps_field, normalized_shape_filed]
- )
+ normalized_shape_filed = trt.PluginField("normalized_shape", normalized_shape, trt.PluginFieldType.INT32)
+ field_collection = trt.PluginFieldCollection([gamma_field, beta_field, eps_field, normalized_shape_filed])
try:
if network.has_implicit_batch_dimension:
plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
else:
plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
except AssertionError:
- _LOGGER.error(
- "Unable to find layer norm plugin, fall back to TensorRT implementation."
- )
+ _LOGGER.error("Unable to find layer norm plugin, fall back to TensorRT implementation.")
return layer_norm(network, target, args, kwargs, name)
layer = network.add_plugin_v2([input_val], plugin)
layer.name = name
return layer.get_output(0)
@@ -678,14 +599,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"LayerNorm received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"LayerNorm received input {input_val} that is not part " "of the TensorRT region!")
shape = kwargs["weight"].shape # type: ignore[union-attr]
broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape
gamma = to_numpy(kwargs["weight"].reshape(*shape)) # type: ignore[union-attr]
beta = to_numpy(kwargs["bias"].reshape(*shape)) # type: ignore[union-attr]
@@ -694,13 +612,11 @@
axes = 0
for d in range(len(shape)):
axes |= 1 << (len(input_val.shape) - d - 1)
# E[x]
- mean_expected_layer = network.add_reduce(
- input_val, trt.ReduceOperation.AVG, axes, keep_dims=True
- )
+ mean_expected_layer = network.add_reduce(input_val, trt.ReduceOperation.AVG, axes, keep_dims=True)
set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
# X-E[x]
sub_trt = add_binary_elementwise_layer(
network,
@@ -722,13 +638,11 @@
pow_tensor.get_output(0),
trt.ElementWiseOperation.POW,
target,
f"{name}_pow_var",
)
- mean_trt_layer = network.add_reduce(
- pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
- )
+ mean_trt_layer = network.add_reduce(pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True)
set_layer_name(mean_trt_layer, target, f"{name}_mean")
# Variance + eps
eps_tensor = network.add_constant(
(1,) * len(input_val.shape),
trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
@@ -741,13 +655,11 @@
trt.ElementWiseOperation.SUM,
target,
f"{name}_add",
)
# SQRT((Var + eps))
- sqrt_trt = add_unary_layer(
- network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt"
- )
+ sqrt_trt = add_unary_layer(network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt")
# (x - E[x]) / sqrt((var + eps))
div_trt = add_binary_elementwise_layer(
network,
sub_trt,
sqrt_trt,
@@ -791,14 +703,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"softmax received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"softmax received input {input_val} that is not part " "of the TensorRT region!")
# Used to get dim when dim is None. Copied from PyTorch softmax implementation.
def get_softmax_dim(ndim: int) -> int:
if ndim == 0 or ndim == 1 or ndim == 3:
ret = 0
@@ -832,13 +741,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
input_val = get_trt_tensor(network, input_t, f"{name}_input")
dims = tuple(cast(Sequence[int], kwargs["dims"]))
- n_input_dims = len(input_val.shape) + (
- 1 if network.has_implicit_batch_dimension else 0
- )
+ n_input_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
if len(dims) > n_input_dims:
assert not network.has_implicit_batch_dimension
layer = network.add_shuffle(input_val)
layer.name = f"{name}_reshape"
@@ -849,20 +756,16 @@
input_shape_layer.name = f"{name}_input_shape"
preceding_ones = network.add_constant(
(num_preceding_ones,),
np.ascontiguousarray([1] * num_preceding_ones, np.int32),
).get_output(0)
- reshape_layer = network.add_concatenation(
- [preceding_ones, input_shape_layer.get_output(0)]
- )
+ reshape_layer = network.add_concatenation([preceding_ones, input_shape_layer.get_output(0)])
reshape_layer.axis = 0
reshape_layer.name = f"{name}_reshape_dims"
layer.set_input(1, reshape_layer.get_output(0))
else:
- layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(
- input_val.shape
- )
+ layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(input_val.shape)
input_val = layer.get_output(0)
else:
dims = (1,) * (n_input_dims - len(dims)) + dims
if network.has_implicit_batch_dimension:
@@ -898,17 +801,15 @@
layer = network.add_slice(input_val, starts, shapes, strides)
layer.mode = trt.SliceMode.WRAP
set_layer_name(layer, target, name)
if has_dynamic_shape(input_val.shape): # type: ignore[union-attr]
- starts_tensor = network.add_constant(
- (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
- ).get_output(0)
+ starts_tensor = network.add_constant((len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)).get_output(
+ 0
+ )
if all(isinstance(d, int) for d in dims):
- dims_tensor = network.add_constant(
- (len(dims),), np.ascontiguousarray(dims, np.int32)
- ).get_output(0)
+ dims_tensor = network.add_constant((len(dims),), np.ascontiguousarray(dims, np.int32)).get_output(0)
else:
assert all(isinstance(d, TRTTensor) for d in dims)
concat_dims_layer = network.add_concatenation(inputs=dims)
concat_dims_layer.axis = 0
concat_dims_layer.name = f"{name}_tile_dim"
@@ -969,13 +870,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
negative_slope = kwargs["negative_slope"]
operation_type = trt.ActivationType.LEAKY_RELU
- return add_activation_layer(
- network, input_val, operation_type, target, name, negative_slope
- )
+ return add_activation_layer(network, input_val, operation_type, target, name, negative_slope)
@tensorrt_converter(acc_ops.elu)
def acc_ops_elu(
network: TRTNetwork,
@@ -1243,51 +1142,40 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
- return add_reduce_layer(
- network, target, args, kwargs, trt.ReduceOperation.SUM, name
- )
+ return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.SUM, name)
@tensorrt_converter(acc_ops.prod)
def acc_ops_prod(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
- return add_reduce_layer(
- network, target, args, kwargs, trt.ReduceOperation.PROD, name
- )
+ return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.PROD, name)
@tensorrt_converter(acc_ops.mean)
def acc_ops_mean(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
- return add_reduce_layer(
- network, target, args, kwargs, trt.ReduceOperation.AVG, name
- )
+ return add_reduce_layer(network, target, args, kwargs, trt.ReduceOperation.AVG, name)
def add_acc_ops_full_reduce(network, target, args, kwargs, name, reduce_op):
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"max received input {input_val} that is not part "
- "of the TensorRT region!"
- )
- assert (
- not network.has_implicit_batch_dimension
- ), "Do not support max over all the elements for implicit batch."
+ raise RuntimeError(f"max received input {input_val} that is not part " "of the TensorRT region!")
+ assert not network.has_implicit_batch_dimension, "Do not support max over all the elements for implicit batch."
dim = range(len(input_val.shape))
layer = network.add_reduce(
input_val,
@@ -1307,25 +1195,21 @@
new_kwargs["largest"] = True
elif reduce_op == trt.ReduceOperation.MIN:
new_kwargs["largest"] = False
new_kwargs["sorted"] = False
- topk_out0, topk_out1 = acc_ops_topk(
- network, target, args, new_kwargs, name + "_topk"
- )
+ topk_out0, topk_out1 = acc_ops_topk(network, target, args, new_kwargs, name + "_topk")
topk_out0.name = f"{name}_topk0"
topk_out1.name = f"{name}_topk1"
if "keepdim" in new_kwargs and new_kwargs["keepdim"]:
return topk_out0, topk_out1
dim = new_kwargs["dim"]
if network.has_implicit_batch_dimension:
- assert (
- dim != 0
- ), "can't reduce on dim == 0 when network has implicit batch dimension"
+ assert dim != 0, "can't reduce on dim == 0 when network has implicit batch dimension"
# we remove the first dim in the shape tuple when it is implicit
dim -= 1
input_val = topk_out0
shape = input_val.shape
@@ -1355,52 +1239,44 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return add_acc_ops_full_reduce(
- network, target, args, kwargs, name, trt.ReduceOperation.MAX
- )
+ return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)
@tensorrt_converter(acc_ops.min_full_reduce, no_implicit_batch_dim=True)
def acc_ops_min_full_reduce(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return add_acc_ops_full_reduce(
- network, target, args, kwargs, name, trt.ReduceOperation.MIN
- )
+ return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)
@tensorrt_converter(acc_ops.max_dim_reduce)
def acc_ops_max_dim_reduce(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return add_acc_ops_dim_reduce(
- network, target, args, kwargs, name, trt.ReduceOperation.MAX
- )
+ return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)
@tensorrt_converter(acc_ops.min_dim_reduce)
def acc_ops_min_dim_reduce(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return add_acc_ops_dim_reduce(
- network, target, args, kwargs, name, trt.ReduceOperation.MIN
- )
+ return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)
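add_acc_ops_dim_reduce (partially shown above) emulates max/min over a single dim with a k=1 TopK instead of a reduce layer, since TopK also returns the indices that torch.max/torch.min expect. Roughly, under the same in-scope assumptions:
import tensorrt as trt
layer = network.add_topk(input_val, trt.TopKOperation.MAX, 1, 1 << dim)
values, indices = layer.get_output(0), layer.get_output(1)  # mirrors torch.max(x, dim)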
@tensorrt_converter(acc_ops.maximum)
def acc_ops_maximum(
network: TRTNetwork,
@@ -1503,32 +1379,24 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `logical_and` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `logical_and` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
# we only support both inputs are bool type
if target == acc_ops.bitwise_and:
def check_is_bool(input_t):
if isinstance(input_t, TRTTensor):
- assert (
- input_t.dtype == trt.bool
- ), "We currently do not support input is non-bool"
+ assert input_t.dtype == trt.bool, "We currently do not support non-bool inputs"
elif isinstance(input_t, torch.Tensor):
- assert (
- input_t.dtype == torch.bool
- ), "We currently do not support input is non-bool"
+ assert input_t.dtype == torch.bool, "We currently do not support non-bool inputs"
else:
- assert isinstance(
- input_t.bool
- ), "We currently do not support input is non-bool"
+ assert isinstance(input_t, bool), "We currently do not support non-bool inputs"
check_is_bool(input_t)
check_is_bool(other_t)
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
@@ -1536,13 +1404,11 @@
if input_t.dtype != trt.bool:
input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool)
if other_t.dtype != trt.bool:
other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.AND, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.AND, target, name)
@tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True)
def acc_ops_ne(
network: TRTNetwork,
@@ -1550,24 +1416,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `ne` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `ne` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
- eq_t = add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
- )
+ eq_t = add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name)
return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name)
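There is no dedicated `ne` op in TensorRT, so the converter composes NOT(EQUAL), which is exactly what the two calls above do. Stripped of the helpers, the composition is approximately:
import tensorrt as trt
eq = network.add_elementwise(input_t, other_t,
                             trt.ElementWiseOperation.EQUAL).get_output(0)
ne = network.add_unary(eq, trt.UnaryOperation.NOT).get_output(0)  # EQUAL yields bool, so NOT applies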
@tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True)
@@ -1577,24 +1439,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `eq` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `eq` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name)
@tensorrt_converter(acc_ops.gt, no_implicit_batch_dim=True)
def acc_ops_gt(
network: TRTNetwork,
@@ -1602,24 +1460,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `gt` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `gt` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name)
@tensorrt_converter(acc_ops.lt, no_implicit_batch_dim=True)
def acc_ops_lt(
network: TRTNetwork,
@@ -1627,24 +1481,20 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `le` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `le` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name)
@tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True)
def acc_ops_logical_or(
network: TRTNetwork,
@@ -1652,13 +1502,11 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `logical_or` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `logical_or` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
if isinstance(other_t, (torch.Tensor, bool)):
if isinstance(other_t, bool):
@@ -1675,13 +1523,11 @@
layer_o = network.add_identity(other_t)
layer_o.set_output_type(0, trt.bool)
set_layer_name(layer_o, target, f"{name}_other_dtype_change")
other_t = layer_o.get_output(0)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.OR, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.OR, target, name)
@tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True)
def acc_ops_logical_xor(
network: TRTNetwork,
@@ -1689,13 +1535,11 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "The `logical_xor` function should be called with explicit batch dimension."
- )
+ raise RuntimeError("The `logical_xor` function should be called with explicit batch dimension.")
input_t = kwargs["input"]
other_t = kwargs["other"]
if isinstance(other_t, (torch.Tensor, bool)):
if isinstance(other_t, bool):
@@ -1712,13 +1556,11 @@
layer_o = network.add_identity(other_t)
layer_o.set_output_type(0, trt.bool)
set_layer_name(layer_o, target, f"{name}_other_dtype_change")
other_t = layer_o.get_output(0)
- return add_binary_elementwise_layer(
- network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name
- )
+ return add_binary_elementwise_layer(network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name)
# T113156424 Have some accuracy problems in hf_T5.
# [TRT] [W] Weights [name=isinf_1_inf_t]: Converted FP32 value in weights (either FP32 infinity or FP32 value outside FP16 range) to corresponding FP16 infinity. If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights.
# @tensorrt_converter(acc_ops.isinf)
@@ -1764,26 +1606,19 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
if not isinstance(input_t, TRTTensor):
- raise RuntimeError(
- f"isinf received input {input_t} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"isinf received input {input_t} that is not part " "of the TensorRT region!")
if input_t.dtype in (trt.float32, trt.float16, trt.int32):
- comp_t = torch.zeros(tuple([*input_t.shape])).to(
- torch_dtype_from_trt(input_t.dtype)
- )
+ comp_t = torch.zeros(tuple([*input_t.shape])).to(torch_dtype_from_trt(input_t.dtype))
comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
kwargs_new = {"input": input_t, "other": comp_t}
eq_output = acc_ops_eq(network, target, None, kwargs_new, name + "_eq")
kwargs_new = {"input": eq_output}
- not_output = acc_ops_logical_not(
- network, target, None, kwargs_new, name + "_not"
- )
+ not_output = acc_ops_logical_not(network, target, None, kwargs_new, name + "_not")
else:
not_output = input_t
# cast bool result to int
int_output = type_cast(network, target, f"{name}_cast_int", not_output, trt.int32)
# sum
@@ -1809,13 +1644,11 @@
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
# NOTE: TRT doesn't currently implement fmod so we need multiple operations to perform it
- trunc_div_value = trunc_div(
- kwargs["input"], kwargs["other"], network, target, name + "_trunc_div"
- )
+ trunc_div_value = trunc_div(kwargs["input"], kwargs["other"], network, target, name + "_trunc_div")
prod_value = add_binary_elementwise_layer(
network,
trunc_div_value,
kwargs["other"],
trt.ElementWiseOperation.PROD,
@@ -1907,14 +1740,11 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_trt = kwargs["input"]
if not isinstance(input_trt, TRTTensor):
- raise RuntimeError(
- f"Max_pool1d received input {input_trt} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Max_pool1d received input {input_trt} that is not part " "of the TensorRT region!")
# adds unsqueeze layer -> max pool 2d -> squeeze layer to emulate max pool 1d.
unsqueeze_layer = network.add_shuffle(input=input_trt)
unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
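The unsqueeze -> 2-D pool -> squeeze emulation used here can be sanity-checked in plain PyTorch (hypothetical sizes, not from the test suite):
import torch
import torch.nn.functional as F
x = torch.randn(1, 2, 8)  # (N, C, L)
ref = F.max_pool1d(x, kernel_size=2, stride=2)
emul = F.max_pool2d(x.unsqueeze(-1), kernel_size=(2, 1), stride=(2, 1)).squeeze(-1)
assert torch.equal(ref, emul)  # pooling over (k, 1) windows reproduces the 1-D pool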
@@ -1929,25 +1759,16 @@
ceil_mode = kwargs["ceil_mode"]
if len(stride) == 0 or stride[0] == None:
stride = kernel_size
- if any(
- [
- not isinstance(param, int)
- for param in [kernel_size[0], stride[0], padding[0], dilation[0]]
- ]
- ):
- raise RuntimeError(
- f"Parameters kernel_size, stride, padding, and dilation should be of type int."
- )
+ if any([not isinstance(param, int) for param in [kernel_size[0], stride[0], padding[0], dilation[0]]]):
+ raise RuntimeError(f"Parameters kernel_size, stride, padding, and dilation should be of type int.")
if dilation[0] != 1:
raise RuntimeError(f"Only support dilation=1 for maxpool, but got {dilation}")
- max_pooling_layer = network.add_pooling(
- input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1)
- )
+ max_pooling_layer = network.add_pooling(input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1))
max_pooling_layer.stride_nd = stride + (1,)
max_pooling_layer.padding_nd = padding + (0,)
set_layer_name(max_pooling_layer, target, name)
if ceil_mode:
@@ -1969,14 +1790,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"MaxPool2d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"MaxPool2d received input {input_val} that is not part " "of the TensorRT region!")
extend_len = 2 if target == acc_ops.max_pool2d else 3
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
dilation = extend_attr_to_tuple(kwargs["dilation"], extend_len)
@@ -1985,17 +1803,13 @@
if len(stride) == 0 or stride[0] == None:
stride = kernel_size
ones = (1,) * extend_len
if dilation != ones:
- raise RuntimeError(
- f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
- )
-
- layer = network.add_pooling_nd(
- input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
- )
+ raise RuntimeError(f"Only support dilation=(1, 1) for maxpool, but got {dilation}")
+
+ layer = network.add_pooling_nd(input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size)
layer.stride_nd = stride
layer.padding_nd = padding
set_layer_name(layer, target, name)
if ceil_mode:
@@ -2013,23 +1827,18 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"squeeze received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"squeeze received input {input_val} that is not part " "of the TensorRT region!")
dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
# Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
# dim, which is a very rare case. For now we simply don't support dim=None.
assert dim is not None, "We don't support dim=None right now for squeeze."
- dim = get_positive_dim(
- dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
- )
+ dim = get_positive_dim(dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
if network.has_implicit_batch_dimension:
assert dim != 0, "We don't support squeeze batch dim when it's implicit."
dim -= 1
assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
@@ -2176,35 +1985,26 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"unsqueeze received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"unsqueeze received input {input_val} that is not part " "of the TensorRT region!")
dim = cast(int, kwargs["dim"])
input_shape = input_val.shape
- input_shape_size = (
- len(input_val.shape) + 1
- if network.has_implicit_batch_dimension
- else len(input_val.shape)
- )
+ input_shape_size = len(input_val.shape) + 1 if network.has_implicit_batch_dimension else len(input_val.shape)
dim = get_positive_dim(dim, input_shape_size + 1)
if network.has_implicit_batch_dimension:
assert dim != 0
dim -= 1
assert (
len(get_dynamic_dims(input_val.shape)) <= 1
), "Currently we don't support unsqueeze with more than one dynamic dims."
layer = network.add_shuffle(input_val)
- layer.reshape_dims = (
- tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
- )
+ layer.reshape_dims = tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
set_layer_name(layer, target, name)
return layer.get_output(0)
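Unsqueeze is a pure reshape in TRT, and the reshape_dims expression above just inserts a 1 at the requested position. Worked example:
shape, dim = (2, 3), 1
reshape_dims = shape[:dim] + (1,) + shape[dim:]
assert reshape_dims == (2, 1, 3)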
@tensorrt_converter(acc_ops.topk)
@@ -2216,14 +2016,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"topk received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"topk received input {input_val} that is not part " "of the TensorRT region!")
if kwargs["sorted"] and kwargs["k"] != 1:
raise RuntimeError("Currently we don't support sorted=True in topk.")
if not network.has_implicit_batch_dimension and len(input_val.shape) <= 1:
@@ -2253,40 +2050,28 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"AdaptiveAvgPool2d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"AdaptiveAvgPool2d received input {input_val} that is not part " "of the TensorRT region!")
extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
assert all(
input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
- output_size = cast(
- Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
- )
+ output_size = cast(Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len))
for input_dim, output_dim in zip(input_val.shape[-extend_len:], output_size):
if input_dim % output_dim != 0:
raise RuntimeError(
"For AdaptiveAvgPool, input dim has to be integer multiple of output dim."
f"Got input dim {input_dim}, output dim {output_dim}"
)
- stride = tuple(
- input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
- )
- kernel_size = tuple(
- input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
- for i in range(extend_len)
- )
- layer = network.add_pooling_nd(
- input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
- )
+ stride = tuple(input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len))
+ kernel_size = tuple(input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i] for i in range(extend_len))
+ layer = network.add_pooling_nd(input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
layer.stride_nd = stride
set_layer_name(layer, target, name)
return layer.get_output(0)
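The stride/kernel arithmetic above is what lets adaptive average pooling lower to a plain pooling layer whenever each input dim is an integer multiple of its output dim. Worked through for a hypothetical 6 -> 3 reduction:
in_dim, out_dim = 6, 3
stride = in_dim // out_dim                # 2
kernel = in_dim - (out_dim - 1) * stride  # 6 - 2*2 = 2
# A 2x2 window with stride 2 reproduces adaptive_avg_pool exactly in this case.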
@@ -2300,14 +2085,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"AvgPool1d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"AvgPool1d received input {input_val} that is not part " "of the TensorRT region!")
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 1)
stride = extend_attr_to_tuple(kwargs["stride"], 1)
padding = extend_attr_to_tuple(kwargs["padding"], 1)
ceil_mode = kwargs["ceil_mode"]
@@ -2319,13 +2101,11 @@
shuffle_layer = network.add_shuffle(input_val)
shuffle_layer.reshape_dims = tuple(input_val.shape) + (1,)
set_layer_name(shuffle_layer, target, name + "_shuffle1")
shuffle_out = shuffle_layer.get_output(0)
- layer = network.add_pooling_nd(
- input=shuffle_out, type=trt.PoolingType.AVERAGE, window_size=(kernel_size[0], 1)
- )
+ layer = network.add_pooling_nd(input=shuffle_out, type=trt.PoolingType.AVERAGE, window_size=(kernel_size[0], 1))
layer.stride_nd = stride + (1,)
layer.padding_nd = padding + (0,)
layer.average_count_excludes_padding = False if count_include_pad else True
set_layer_name(layer, target, name)
@@ -2349,14 +2129,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"AvgPool2d received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"AvgPool2d received input {input_val} that is not part " "of the TensorRT region!")
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 2)
stride = extend_attr_to_tuple(kwargs["stride"], 2)
padding = extend_attr_to_tuple(kwargs["padding"], 2)
ceil_mode = kwargs["ceil_mode"]
@@ -2367,13 +2144,11 @@
stride = kernel_size
if divisor_override:
raise RuntimeError("TensorRT does not support divisor_override.")
- layer = network.add_pooling(
- input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
- )
+ layer = network.add_pooling(input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
layer.stride = stride
layer.padding = padding
layer.average_count_excludes_padding = False if count_include_pad else True
set_layer_name(layer, target, name)
@@ -2433,23 +2208,18 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"slice_tensor received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"slice_tensor received input {input_val} that is not part " "of the TensorRT region!")
ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
dynamic_shape = has_dynamic_shape(input_val.shape)
if network.has_implicit_batch_dimension:
if dim == 0:
- raise RuntimeError(
- f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
- )
+ raise RuntimeError(f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!")
dim = dim - 1
else:
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
@@ -2463,13 +2233,11 @@
stride[dim] = step_int
output_shape = list(input_val.shape)
output_shape[dim] = (stop_int - start_int) // step_int
if dynamic_shape > 0:
- output_shape = get_shape_with_dynamic_shape(
- network, output_shape, input_val, target, name
- )
+ output_shape = get_shape_with_dynamic_shape(network, output_shape, input_val, target, name)
layer = network.add_slice(
input_val,
start=start,
shape=[] if dynamic_shape else output_shape,
stride=stride,
@@ -2502,13 +2270,11 @@
shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
inshape = tuple(input_val.shape)
shape = tuple(shape)
start = tuple([0] * ranks)
- stride = tuple(
- [int(i == o) for i, o in zip(inshape, shape)]
- ) # stride == 1 if dimensions match, 0 otherwise
+ stride = tuple([int(i == o) for i, o in zip(inshape, shape)]) # stride == 1 if dimensions match, 0 otherwise
layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
set_layer_name(layer, target, name)
return layer.get_output(0)
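The `int(i == o)` expression implements expand/broadcast with a slice layer: stride 0 re-reads a size-1 dim, stride 1 walks a matching dim. For a hypothetical (1, 3) -> (4, 3) expand:
inshape, shape = (1, 3), (4, 3)
stride = tuple(int(i == o) for i, o in zip(inshape, shape))
assert stride == (0, 1)  # repeat the single row four times, copy the three columns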
@@ -2615,13 +2381,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_t = kwargs["input"]
mask_t = kwargs["mask"]
value_t = kwargs["value"]
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "We don't support masked_fill with implicit batch dimension due to select layer!"
- )
+ raise RuntimeError("We don't support masked_fill with implicit batch dimension due to select layer!")
shape = list(input_t.shape)
mask_shape = list(mask_t.shape)
assert type(value_t) in (
@@ -2674,14 +2438,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"split received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"split received input {input_val} that is not part " "of the TensorRT region!")
dim = cast(int, kwargs["dim"])
dynamic_shape = has_dynamic_shape(input_val.shape)
if network.has_implicit_batch_dimension:
assert dim != 0, "Can't split on batch dim when it's implicit!"
@@ -2695,28 +2456,22 @@
start = [0] * len(input_val.shape)
stride = [1] * len(start)
offset = 0
num_splits = (input_val.shape[dim] + split_size - 1) // split_size
if num_splits < 1:
- raise RuntimeError(
- f"Invalid split: {input_val.shape[dim]} with split_size={split_size}"
- )
+ raise RuntimeError(f"Invalid split: {input_val.shape[dim]} with split_size={split_size}")
max_offset = input_val.shape[dim]
# add slice layers
output = []
for i in range(num_splits):
shape = list(input_val.shape)
shape[dim] = min(split_size, cast(int, max_offset - offset))
start[dim] = offset
if dynamic_shape:
- shape = get_shape_with_dynamic_shape(
- network, shape, input_val, target, f"{name}_shape_{i}"
- )
- layer = network.add_slice(
- input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
- )
+ shape = get_shape_with_dynamic_shape(network, shape, input_val, target, f"{name}_shape_{i}")
+ layer = network.add_slice(input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride)
if dynamic_shape:
layer.set_input(2, shape)
offset += split_size
set_layer_name(layer, target, f"{name}_{i}")
output.append(layer.get_output(0))
@@ -2732,19 +2487,15 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Linear received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Linear received input {input_val} that is not part " "of the TensorRT region!")
dynamic_dims = get_dynamic_dims(input_val.shape)
assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, (
- "Currently we only support one dynmaic "
- "dim for linear and it can't be the last dim."
+ "Currently we only support one dynmaic " "dim for linear and it can't be the last dim."
)
if isinstance(kwargs["weight"], torch.Tensor):
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
weight_op = trt.MatrixOperation.NONE
@@ -2760,13 +2511,11 @@
preset_diff -= 1
input_op = trt.MatrixOperation.VECTOR
else:
input_op = trt.MatrixOperation.NONE
- input_val, weight = broadcast(
- network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff
- )
+ input_val, weight = broadcast(network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff)
matmul_layer = network.add_matrix_multiply(input_val, input_op, weight, weight_op)
set_layer_name(matmul_layer, target, f"{name}_matmul")
res = matmul_layer.get_output(0)
if kwargs["bias"] is not None:
@@ -2782,16 +2531,11 @@
return res
def add_clamp(network, input, val, op):
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
- acc_ops_clamp_tensor = (
- val
- * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
- .cpu()
- .numpy()
- )
+ acc_ops_clamp_tensor = val * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)).cpu().numpy()
acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor)
layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op)
return layer
@@ -2807,25 +2551,18 @@
input_val = kwargs["input"]
min_val = kwargs["min"]
max_val = kwargs["max"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Clamp received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Clamp received input {input_val} that is not part " "of the TensorRT region!")
if min_val is not None:
- clamp_min_layer = add_clamp(
- network, input_val, min_val, trt.ElementWiseOperation.MAX
- )
+ clamp_min_layer = add_clamp(network, input_val, min_val, trt.ElementWiseOperation.MAX)
set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
input_val = clamp_min_layer.get_output(0)
if max_val is not None:
- clamp_max_layer = add_clamp(
- network, input_val, max_val, trt.ElementWiseOperation.MIN
- )
+ clamp_max_layer = add_clamp(network, input_val, max_val, trt.ElementWiseOperation.MIN)
set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
input_val = clamp_max_layer.get_output(0)
return input_val
@@ -2883,30 +2620,22 @@
def slice_to_trt_params(py_slice, dim_size):
"""
Convert python slice to TensorRT slice layer parameters.
"""
- start = (
- get_positive_dim(py_slice.start, dim_size) if py_slice.start != None else 0
- )
+ start = get_positive_dim(py_slice.start, dim_size) if py_slice.start != None else 0
stride = py_slice.step if py_slice.step != None else 1
- stop = (
- get_positive_dim(py_slice.stop, dim_size)
- if py_slice.stop != None
- else dim_size
- )
+ stop = get_positive_dim(py_slice.stop, dim_size) if py_slice.stop != None else dim_size
size = math.ceil((stop - start) * 1.0 / stride)
return start, size, stride
if network.has_implicit_batch_dimension:
# Raise an error if it's trying to subscript batch dimension unless it's
# slice(None, None, None).
batch_subscript = slices[0]
if batch_subscript not in [slice(None, None, None), slice(0, None, None)]:
- raise RuntimeError(
- f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}"
- )
+ raise RuntimeError(f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}")
# Remove batch_dim subscript
slices = slices[1:]
# Replace ellipsis with expanded slices.
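slice_to_trt_params is ordinary slice arithmetic; for example, `x[2:9:3]` on a dimension of size 10 (hypothetical) resolves to:
import math
start, stop, stride = 2, 9, 3
size = math.ceil((stop - start) * 1.0 / stride)  # ceil(7 / 3) = 3
# TRT slice params (start=2, size=3, stride=3) select indices 2, 5, 8.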
@@ -2995,13 +2724,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
tensors = kwargs["tensors"]
dim = kwargs["dim"]
if any(not isinstance(t, TRTTensor) for t in tensors): # type: ignore[union-attr]
- raise RuntimeError(
- f"cat received inputs {tensors} that is not part " "of the TensorRT region!"
- )
+ raise RuntimeError(f"cat received inputs {tensors} that is not part " "of the TensorRT region!")
layer = network.add_concatenation(inputs=tensors)
if dim < 0:
if network.has_implicit_batch_dimension:
dim = len(tensors[0].shape) + 1 + dim
else:
@@ -3023,13 +2750,11 @@
input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other")
for i in [input_val, other_val]:
if not isinstance(i, TRTTensor):
- raise RuntimeError(
- f"matmul received input {i} that is not part of the TensorRT region!"
- )
+ raise RuntimeError(f"matmul received input {i} that is not part of the TensorRT region!")
input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
preset_diff = 0
if len(input_val.shape) == 1:
@@ -3038,16 +2763,12 @@
if len(other_val.shape) == 1:
preset_diff += 1
other_matrix_op = trt.MatrixOperation.VECTOR
- input_val, other_val = broadcast(
- network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
- )
- layer = network.add_matrix_multiply(
- input_val, input_matrix_op, other_val, other_matrix_op
- )
+ input_val, other_val = broadcast(network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff)
+ layer = network.add_matrix_multiply(input_val, input_matrix_op, other_val, other_matrix_op)
set_layer_name(layer, target, name)
return layer.get_output(0)
@tensorrt_converter(acc_ops.hardsigmoid)
@@ -3059,14 +2780,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Hard sigmoid received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"Hard sigmoid received input {input_val} that is not part " "of the TensorRT region!")
return add_activation_layer(
network,
input_val,
trt.ActivationType.HARD_SIGMOID,
@@ -3086,18 +2804,13 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"Sigmoid received input {input_val} that is not part "
- "of the TensorRT region!"
- )
-
- return add_activation_layer(
- network, input_val, trt.ActivationType.SIGMOID, target, name
- )
+ raise RuntimeError(f"Sigmoid received input {input_val} that is not part " "of the TensorRT region!")
+
+ return add_activation_layer(network, input_val, trt.ActivationType.SIGMOID, target, name)
@tensorrt_converter(acc_ops.permute)
def acc_ops_permute(
network: TRTNetwork,
@@ -3113,14 +2826,11 @@
else:
index = kwargs["permutation"]
permutation = [get_positive_dim(i, ranks) for i in cast(Sequence[int], index)]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"permute received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"permute received input {input_val} that is not part " "of the TensorRT region!")
if network.has_implicit_batch_dimension:
assert permutation[0] == 0, "Can't permute batch dimension when it's implicit."
permutation = [i - 1 for i in permutation[1:]]
@@ -3139,14 +2849,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"{name} received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")
qparams = kwargs["acc_out_ty"].qparams # type: ignore[misc]
q_scale = qparams["scale"]
q_zero_point = qparams["zero_point"]
dtype = kwargs["acc_out_ty"].dtype # type: ignore[misc]
@@ -3157,13 +2864,11 @@
)
if q_zero_point != 0:
raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
- scale_layer = network.add_constant(
- (1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32))
- )
+ scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
scale_layer.name = input_val.name + ".per_tensor_quant.scale"
scale = scale_layer.get_output(0)
# assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
# "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
layer = network.add_quantize(input=input_val, scale=scale)
@@ -3181,14 +2886,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"{name} received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")
qparams = kwargs["acc_out_ty"].qparams # type: ignore[misc]
q_per_channel_scales = qparams["scale"]
q_per_channel_zero_points = qparams["zero_point"]
q_per_channel_axis = qparams["axis"]
@@ -3201,17 +2903,13 @@
# Make sure zero_points are all 0 because only symmetric quantization
# is supported in TensorRT
if not torch.equal(
q_per_channel_zero_points,
- torch.zeros(
- q_per_channel_zero_points.shape, dtype=q_per_channel_zero_points.dtype
- ),
+ torch.zeros(q_per_channel_zero_points.shape, dtype=q_per_channel_zero_points.dtype),
):
- raise RuntimeError(
- f"Only support zero_point == 0, get {q_per_channel_zero_points}"
- )
+ raise RuntimeError(f"Only support zero_point == 0, get {q_per_channel_zero_points}")
if not torch.all(torch.ge(q_per_channel_scales, 0)):
raise RuntimeError(f"All scale values must be >= 0, get {q_per_channel_scales}")
scale_layer = network.add_constant(
@@ -3238,14 +2936,11 @@
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
input_val_tensor_meta = kwargs["_itensor_to_tensor_meta"][input_val] # type: ignore[index]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"{name} received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"{name} received input {input_val} that is not part " "of the TensorRT region!")
qparams = input_val_tensor_meta.qparams # type: ignore[misc]
qscheme = qparams["qscheme"]
if qscheme == torch.per_tensor_affine:
q_scale = qparams["scale"]
@@ -3256,30 +2951,25 @@
raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
elif qscheme == torch.per_channel_affine:
q_scale = qparams["scale"]
q_zero_point = qparams["zero_point"]
q_axis = qparams["axis"]
- assert isinstance(
- q_scale, immutable_list
- ), "expected q_scale to be immutable_list got {}".format(type(q_scale))
+ assert isinstance(q_scale, immutable_list), "expected q_scale to be immutable_list got {}".format(type(q_scale))
scale_shape = (len(q_scale),)
if any(x != 0 for x in q_zero_point):
raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
else:
raise RuntimeError("Unsupported qscheme in dequantize: {qscheme}")
dtype = input_val_tensor_meta.dtype # type: ignore[misc]
if dtype not in (torch.quint8, torch.qint8, torch.qint32):
raise RuntimeError(
- "Only support (torch.quint8, torch.qint8, torch.qint32) "
- f"quantized type in dequantize, get {dtype}."
+ "Only support (torch.quint8, torch.qint8, torch.qint32) " f"quantized type in dequantize, get {dtype}."
)
- scale_layer = network.add_constant(
- scale_shape, trt.Weights(np.ascontiguousarray(q_scale, dtype=np.float32))
- )
+ scale_layer = network.add_constant(scale_shape, trt.Weights(np.ascontiguousarray(q_scale, dtype=np.float32)))
scale_layer.name = input_val.name + ".dequant.scale"
scale = scale_layer.get_output(0)
# assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
# "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
layer = network.add_dequantize(input=input_val, scale=scale)
@@ -3296,24 +2986,17 @@
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"GELU received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"GELU received input {input_val} that is not part " "of the TensorRT region!")
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "GeLU converter currently doesn't support implicit batch dimension"
- )
+ raise RuntimeError("GeLU converter currently doesn't support implicit batch dimension")
plugin_name = "CustomGeluPluginDynamic"
# type_id 0 for float32, 1 for float16
- type_id = trt.PluginField(
- "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
- )
+ type_id = trt.PluginField("type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32)
field_collection = TRTPluginFieldCollection([type_id])
plugin_version = "1"
plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
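The hunk is cut off after the plugin lookup; a plausible continuation (a sketch, not the file's actual next lines) would add the plugin layer through the standard IPluginV2 entry point:
layer = network.add_plugin_v2(inputs=[input_val], plugin=plugin)
set_layer_name(layer, target, name)
out = layer.get_output(0)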
@@ -3334,14 +3017,11 @@
chunks = cast(int, kwargs["chunks"])
dim = cast(int, kwargs["dim"])
input_dim_size = len(input_val.shape) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"chunk received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"chunk received input {input_val} that is not part " "of the TensorRT region!")
dynamic_shape = has_dynamic_shape(input_val.shape)
if network.has_implicit_batch_dimension:
input_dim_size += 1
dim = get_positive_dim(dim, input_dim_size)
@@ -3371,17 +3051,13 @@
output = []
for i in range(chunks):
shape = list(input_val.shape)
shape[dim] = min(split_size, max_offset - offset)
if dynamic_shape:
- shape = get_shape_with_dynamic_shape(
- network, shape, input_val, target, f"{name}_{i}"
- )
+ shape = get_shape_with_dynamic_shape(network, shape, input_val, target, f"{name}_{i}")
start[dim] = offset
- layer = network.add_slice(
- input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
- )
+ layer = network.add_slice(input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride)
if dynamic_shape:
layer.set_input(2, shape)
offset += split_size
set_layer_name(layer, target, f"{name}_{i}")
output.append(layer.get_output(0))
@@ -3400,18 +3076,13 @@
dim = cast(int, kwargs["dim"])
input_shape = input_val.shape # type: ignore[union-attr]
input_dim_size = len(input_val.shape) # type: ignore[union-attr]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"cumsum received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"cumsum received input {input_val} that is not part " "of the TensorRT region!")
if network.has_implicit_batch_dimension:
- raise RuntimeError(
- "cumsum converter currently doesn't support implicit batch dimension"
- )
+ raise RuntimeError("cumsum converter currently doesn't support implicit batch dimension")
dim = get_positive_dim(dim, input_dim_size)
loop = network.add_loop()
trip_limit = None
if input_shape[dim] > 0:
axis = torch.tensor(input_shape[dim], dtype=torch.int32)
@@ -3427,13 +3098,11 @@
loop.add_trip_limit(trip_limit, trt.TripLimit(0))
iterator = loop.add_iterator(input_val, dim, False)
data = iterator.get_output(0)
new_dims = tuple(data.shape)
zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
- zero_tensor = network.add_constant(
- zero_tensor.shape, to_numpy(zero_tensor)
- ).get_output(0)
+ zero_tensor = network.add_constant(zero_tensor.shape, to_numpy(zero_tensor)).get_output(0)
running_sum = loop.add_recurrence(zero_tensor)
set_layer_name(running_sum, target, f"{name}_running_sum_1")
running_sum_tensor = running_sum.get_output(0)
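cumsum is the one converter in this diff built on TRT's loop construct: an iterator feeds one slice per trip into a recurrence holding the running sum. A condensed sketch of that wiring, with all names assumed in scope:
import tensorrt as trt
loop = network.add_loop()
loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)  # run input_shape[dim] trips
it = loop.add_iterator(input_val, dim, False)         # yields one slice per trip
acc = loop.add_recurrence(zero_tensor)                # running sum, zero-initialised
s = network.add_elementwise(it.get_output(0), acc.get_output(0),
                            trt.ElementWiseOperation.SUM).get_output(0)
acc.set_input(1, s)                                   # feed the new sum back in
out_layer = loop.add_loop_output(s, trt.LoopOutput.CONCATENATE, dim)
out_layer.set_input(1, trip_limit)                    # CONCATENATE needs the length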
@@ -3476,14 +3145,11 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"hardtanh received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"hardtanh received input {input_val} that is not part " "of the TensorRT region!")
return add_activation_layer(
network,
input_val,
trt.ActivationType.CLIP,
@@ -3507,26 +3173,19 @@
scale_factor = kwargs["scale_factor"]
mode = kwargs["mode"]
align_corners = kwargs["align_corners"]
if not isinstance(input_val, TRTTensor):
- raise RuntimeError(
- f"interpolate received input {input_val} that is not part "
- "of the TensorRT region!"
- )
+ raise RuntimeError(f"interpolate received input {input_val} that is not part " "of the TensorRT region!")
dim = input_val.shape
ranks = len(input_val.shape)
if network.has_implicit_batch_dimension:
- assert (
- ranks >= 2 and ranks <= 4
- ), "Interpolate expects inputs are 3D,4D,5D in shape"
+ assert ranks >= 2 and ranks <= 4, "Interpolate expects inputs to be 3D, 4D, or 5D in shape"
ranks = ranks - 1
else:
- assert (
- ranks >= 3 and ranks <= 5
- ), "Interpolate expects inputs are 3D,4D,5D in shape"
+ assert ranks >= 3 and ranks <= 5, "Interpolate expects inputs to be 3D, 4D, or 5D in shape"
ranks = ranks - 2
layer = network.add_resize(input_val)
if network.has_implicit_batch_dimension:
if size != None:
@@ -3555,13 +3214,11 @@
layer.resize_mode = trt.ResizeMode.LINEAR
else:
layer.resize_mode = trt.ResizeMode.NEAREST
if align_corners != None:
- layer.coordinate_transformation = (
- trt.ResizeCoordinateTransformation.ALIGN_CORNERS
- )
+ layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ALIGN_CORNERS
set_layer_name(layer, target, name)
return layer.get_output(0)
@@ -3579,13 +3236,11 @@
if dtype_val is None:
dtype_val = input_val.dtype
dtype_val = torch_dtype_from_trt(dtype_val)
device_val = kwargs.get("device")
- assert (
- device_val == "cuda" or device_val == None
- ), f"device is not `cuda` but {device_val}"
+ assert device_val == "cuda" or device_val == None, f"device is not `cuda` but {device_val}"
weight = torch.ones(size_val, dtype=dtype_val)
return get_trt_tensor(network, weight, f"{name}_weight")
@@ -3603,13 +3258,11 @@
if dtype_val is None:
dtype_val = input_val.dtype
dtype_val = torch_dtype_from_trt(dtype_val)
device_val = kwargs.get("device")
- assert (
- device_val == "cuda" or device_val == None
- ), f"device is not `cuda` but {device_val}"
+ assert device_val == "cuda" or device_val == None, f"device is not `cuda` but {device_val}"
weight = torch.zeros(size_val, dtype=dtype_val)
return get_trt_tensor(network, weight, f"{name}_weight")
@@ -3634,13 +3287,11 @@
input_val[i] = get_trt_tensor(network, input_source, name + f"_input_source{i}")
if const_flag:
for i, input_source in enumerate(input_val):
if input_source.dtype != trt.float32:
- input_val[i] = type_cast(
- network, target, f"{name}_input_cast{i}", input_source, trt.float32
- )
+ input_val[i] = type_cast(network, target, f"{name}_input_cast{i}", input_source, trt.float32)
einsum_layer = network.add_einsum(inputs=input_val, equation=equation)
return einsum_layer.get_output(0)
@tensorrt_converter(acc_ops.as_strided)
--- py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py 2022-08-12 19:20:00.656595 +0000
@@ -70,13 +70,11 @@
inputs,
expected_ops={expected_acc_op},
test_implicit_batch_dim=(dim != 0),
)
- @parameterized.expand(
- [(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)]
- )
+ @parameterized.expand([(f"{acc_ops.prod.__name__}_no_dim_no_keepdim", torch.prod, acc_ops.prod)])
def test_prod_all_dims(
self,
test_name,
op,
expected_acc_op,
@@ -107,12 +105,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
),
]
- self.run_test_with_dynamic_shape(
- Prod(), input_specs, expected_ops={acc_ops.prod}
- )
+ self.run_test_with_dynamic_shape(Prod(), input_specs, expected_ops={acc_ops.prod})
if __name__ == "__main__":
run_tests()
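For readers skimming the test reflows: parameterized.expand turns each tuple into its own test case, using the first element as the test-name suffix. A self-contained illustration (hypothetical test, not from this repo):
import unittest
from parameterized import parameterized

class TestDouble(unittest.TestCase):
    @parameterized.expand([("two", 2), ("three", 3)])
    def test_double(self, _name, n):
        self.assertEqual(n * 2, n + n)  # runs as test_double_0_two and test_double_1_three

if __name__ == "__main__":
    unittest.main()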
--- py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py 2022-08-12 19:20:00.693911 +0000
@@ -50,16 +50,11 @@
inputs,
expected_ops={expected_acc_op},
test_implicit_batch_dim=(dim != 0),
)
- @parameterized.expand(
- [
- (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op)
- for op, acc_op in reduce_ops
- ]
- )
+ @parameterized.expand([(f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) for op, acc_op in reduce_ops])
def test_reduce_all_dims(
self,
test_name,
op,
expected_acc_op,
@@ -74,16 +69,11 @@
inputs,
expected_ops={expected_acc_op},
test_implicit_batch_dim=False,
)
- @parameterized.expand(
- [
- (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op)
- for op, acc_op in reduce_ops
- ]
- )
+ @parameterized.expand([(f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) for op, acc_op in reduce_ops])
def test_reduce_all_dims_with_dynamic_shape_four_dimensions(
self,
test_name,
op,
expected_acc_op,
@@ -97,12 +87,10 @@
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Reduce(), input_specs, expected_ops={expected_acc_op}
- )
+ self.run_test_with_dynamic_shape(Reduce(), input_specs, expected_ops={expected_acc_op})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py 2022-08-12 19:20:00.870688 +0000
@@ -26,14 +26,11 @@
inputs = [torch.randn(*input_shape)]
self.run_test(
Tile(dims),
inputs,
expected_ops={acc_ops.tile},
- test_implicit_batch_dim=(
- len(input_shape) > len(dims)
- or (len(input_shape) == len(dims) and dims[0] == 1)
- ),
+ test_implicit_batch_dim=(len(input_shape) > len(dims) or (len(input_shape) == len(dims) and dims[0] == 1)),
)
@parameterized.expand(
[
("same_num_dims", (-1, 2, 3), (1, 2, 2)),
@@ -62,13 +59,11 @@
tuple(i if i != -1 else 3 for i in shape),
)
],
),
]
- self.run_test_with_dynamic_shape(
- Tile(dims), input_specs, expected_ops={acc_ops.tile}
- )
+ self.run_test_with_dynamic_shape(Tile(dims), input_specs, expected_ops={acc_ops.tile})
@parameterized.expand(
[
("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)),
]
@@ -88,13 +83,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Tile(dims), input_specs, expected_ops={acc_ops.tile}
- )
+ self.run_test_with_dynamic_shape(Tile(dims), input_specs, expected_ops={acc_ops.tile})
def test_tile_non_int_dims(self):
class Tile(nn.Module):
def __init__(self):
super().__init__()
@@ -103,13 +96,11 @@
y = y * 2
return torch.tile(x, (1, y.shape[1], y.shape[1]))
inputs = [torch.randn(2, 2, 3), torch.randn(2, 2, 3)]
batch_size_range = (1, 2, 3)
- input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
- inputs, batch_size_range
- )
+ input_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(inputs, batch_size_range)
self.run_test_with_dynamic_shape(
Tile(),
input_specs,
expected_ops={acc_ops.tile},
)
@@ -134,12 +125,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- Tile(), input_specs, expected_ops={acc_ops.tile}
- )
+ self.run_test_with_dynamic_shape(Tile(), input_specs, expected_ops={acc_ops.tile})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py 2022-08-12 19:20:00.924215 +0000
@@ -24,13 +24,11 @@
self.dim = dim
self.largest = largest
def forward(self, x):
if self.dim is not None:
- out = torch.topk(
- x, k=self.k, dim=self.dim, largest=self.largest, sorted=False
- )
+ out = torch.topk(x, k=self.k, dim=self.dim, largest=self.largest, sorted=False)
else:
out = torch.topk(x, k=self.k, largest=self.largest, sorted=False)
return out[0], out[1]
inputs = [torch.randn(1, 2, 3, 4)]
@@ -58,13 +56,11 @@
self.dim = dim
self.largest = largest
def forward(self, x):
if self.dim is not None:
- out = torch.topk(
- x, k=self.k, dim=self.dim, largest=self.largest, sorted=False
- )
+ out = torch.topk(x, k=self.k, dim=self.dim, largest=self.largest, sorted=False)
else:
out = torch.topk(x, k=self.k, largest=self.largest, sorted=False)
return out[0], out[1]
input_specs = [
@@ -73,12 +69,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TopK(k, dim), input_specs, expected_ops={acc_ops.topk}
- )
+ self.run_test_with_dynamic_shape(TopK(k, dim), input_specs, expected_ops={acc_ops.topk})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py 2022-08-12 19:20:00.924826 +0000
@@ -51,13 +51,11 @@
input = torch.randn(2, 2).to(torch.float16)
inputs = [
input,
]
- self.run_test(
- To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False
- )
+ self.run_test(To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False)
def test_cuda_fp16(self):
class To(torch.nn.Module):
def forward(self, x):
return x.to(torch.device("cuda:0"), torch.float16)
@@ -106,13 +104,11 @@
dtype=torch.float16,
shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add}
- )
+ self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add})
def test_device(self):
class To(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -152,13 +148,11 @@
dtype=torch.float16,
shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add}
- )
+ self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add})
def test_device_fp16(self):
class To(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -244,13 +238,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- To(), input_specs, expected_ops={acc_ops.to_dtype}
- )
+ self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype})
# Half is not suitable for dynamic shape
# Error: assert engine
# tensor.half()
@@ -307,12 +299,10 @@
dtype=torch.int,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- To(), input_specs, expected_ops={acc_ops.to_dtype}
- )
+ self.run_test_with_dynamic_shape(To(), input_specs, expected_ops={acc_ops.to_dtype})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py 2022-08-12 19:20:01.054432 +0000
@@ -62,13 +62,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(orig_op), input_specs, expected_ops={expected_op}
- )
+ self.run_test_with_dynamic_shape(TestModule(orig_op), input_specs, expected_ops={expected_op})
class TestUnaryOpNotConverters(AccTestCase):
@parameterized.expand(
[
@@ -87,13 +85,11 @@
x = self.orig_op(x)
return self.orig_op(x)
m = TestModule(orig_op)
inputs = [torch.randn(2, 2, 3).to(input_dtype)]
- self.run_test(
- m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False
- )
+ self.run_test(m, inputs, expected_ops={expected_op}, test_implicit_batch_dim=False)
class TestUnaryOpNotConvertersWithDynamicShapeFourDimensions(AccTestCase):
@parameterized.expand(
[
@@ -118,13 +114,11 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(orig_op), input_specs, expected_ops={expected_op}
- )
+ self.run_test_with_dynamic_shape(TestModule(orig_op), input_specs, expected_ops={expected_op})
class TestUnaryRSQRTConverters(AccTestCase):
def test_unary_ops(self):
class TestModule(nn.Module):
@@ -148,12 +142,10 @@
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
),
]
- self.run_test_with_dynamic_shape(
- TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal}
- )
+ self.run_test_with_dynamic_shape(TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal})
if __name__ == "__main__":
run_tests()
--- py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py 2022-08-12 19:16:11.712868 +0000
+++ py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py 2022-08-12 19:20:01.105902 +0000
@@ -35,26 +35,22 @@
self._validate_spec(spec, tensor)
def test_from_tensors_with_dynamic_batch_size(self):
tensors = [torch.randn(1, 2, 3), torch.randn(1, 4)]
batch_size_range = [2, 3, 4]
- specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
- tensors, batch_size_range
- )
+ specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(tensors, batch_size_range)
for spec, tensor in zip(specs, tensors):
self._validate_spec(spec, tensor, dynamic_dims=[0])
for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
self.assertEqual(batch_size, shape[0])
self.assertSequenceEqual(tensor.shape[1:], shape[1:])
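The assertions above pin down the contract: the batch dim becomes -1 in spec.shape, and each entry of batch_size_range is substituted into shape_ranges. Concretely, for a hypothetical (1, 3) tensor (assuming InputTensorSpec is importable from torch_tensorrt.fx, as these tests do):
import torch
from torch_tensorrt.fx import InputTensorSpec

specs = InputTensorSpec.from_tensors_with_dynamic_batch_size([torch.randn(1, 3)], [2, 3, 4])
assert tuple(specs[0].shape) == (-1, 3)
assert specs[0].shape_ranges[0] == ((2, 3), (3, 3), (4, 3))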
def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self):
tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)]
batch_size_range = [2, 3, 4]
- specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
- tensors, batch_size_range, batch_dims=[0, 1]
- )
+ specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(tensors, batch_size_range, batch_dims=[0, 1])
for i, spec_and_tensor in enumerate(zip(specs, tensors)):
spec, tensor = spec_and_tensor
self._validate_spec(spec, tensor, dynamic_dims=[i])
for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
@@ -62,13 +58,11 @@
tensor_shape = list(tensor.shape)
tensor_shape[i] = batch_size
self.assertSequenceEqual(tensor_shape, shape)
def test_generate_input_specs(self):
- lower_setting = LowerSetting(
- explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2
- )
+ lower_setting = LowerSetting(explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2)
# Implicit batch dim.
inputs = [torch.randn(1, 2, 3)]
specs = generate_input_specs(inputs, lower_setting)
for spec, tensor in zip(specs, inputs):
--- py/torch_tensorrt/fx/test/quant/test_quant_trt.py 2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/test/quant/test_quant_trt.py 2022-08-12 19:20:01.635277 +0000
@@ -46,13 +46,11 @@
shape_ranges=shape_ranges,
has_batch_dim=True,
)
]
- interp = TRTInterpreter(
- model, input_specs, explicit_batch_dimension=True, explicit_precision=True
- )
+ interp = TRTInterpreter(model, input_specs, explicit_batch_dimension=True, explicit_precision=True)
result = interp.run(lower_precision=LowerPrecision.INT8)
trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
return trt_mod
@@ -65,13 +63,11 @@
),
weight=torch.ao.quantization.default_weight_observer,
)
self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
- def _test_quantized_inputs_outputs(
- self, prepare_custom_config_dict, prepare_count_check, convert_count_check
- ):
+ def _test_quantized_inputs_outputs(self, prepare_custom_config_dict, prepare_count_check, convert_count_check):
"""
Test the option to have inputs and outputs of the graph quantized
"""
class M(torch.nn.Module):
@@ -113,13 +109,11 @@
# output of ref conv1 and output of ref conv2
ns.call_function(torch.quantize_per_tensor): 2,
# input of ref conv1 and input of ref conv2
ns.call_method("dequantize"): 2,
}
- self._test_quantized_inputs_outputs(
- prepare_custom_config_dict, prepare_count_check, convert_count_check
- )
+ self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)
def test_fp32_input_quantized_output(self):
prepare_custom_config_dict = {"output_quantized_idxs": [0]}
prepare_count_check = {
ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
@@ -128,13 +122,11 @@
# input, output of conv1 and output of conv2
ns.call_function(torch.quantize_per_tensor): 3,
# input of conv1, conv2
ns.call_method("dequantize"): 2,
}
- self._test_quantized_inputs_outputs(
- prepare_custom_config_dict, prepare_count_check, convert_count_check
- )
+ self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)
def test_quantized_input_fp32_output(self):
prepare_custom_config_dict = {"input_quantized_idxs": [0]}
prepare_count_check = {
ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
@@ -143,26 +135,22 @@
# output of conv1, conv2
ns.call_function(torch.quantize_per_tensor): 2,
# input of ref conv1, input of ref conv2, final output
ns.call_method("dequantize"): 3,
}
- self._test_quantized_inputs_outputs(
- prepare_custom_config_dict, prepare_count_check, convert_count_check
- )
+ self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)
def test_fp32_input_fp32_output(self):
prepare_custom_config_dict = {}
prepare_count_check = {
ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
}
convert_count_check = {
ns.call_function(torch.quantize_per_tensor): 3,
ns.call_method("dequantize"): 3,
}
- self._test_quantized_inputs_outputs(
- prepare_custom_config_dict, prepare_count_check, convert_count_check
- )
+ self._test_quantized_inputs_outputs(prepare_custom_config_dict, prepare_count_check, convert_count_check)
def _test_standalone_module(
self,
interface_config,
prepare_count_check,
@@ -213,20 +201,14 @@
data = torch.randn(1, 1, 1, 1)
# instantiate M and RefM and align the parameters
original_m = M().eval()
original_ref_m = RefM().eval()
- original_ref_m.conv1.weight = torch.nn.Parameter(
- original_m.conv.weight.detach()
- )
+ original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
- original_ref_m.conv2.weight = torch.nn.Parameter(
- original_m.standalone.conv.weight.detach()
- )
- original_ref_m.conv2.bias = torch.nn.Parameter(
- original_m.standalone.conv.bias.detach()
- )
+ original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
+ original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())
sm_example_inputs = (data,)
prepare_config = {
"standalone_module_name": [
(
@@ -253,20 +235,16 @@
backend_config=backend_config_dict,
)
# calibration
m(data)
self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
- self.checkGraphModuleNodes(
- m.standalone, expected_node_occurrence=standalone_prepare_count_check
- )
+ self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
# check converted/quantized model
m = convert_to_reference_fx(m, backend_config=backend_config_dict)
self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
- self.checkGraphModuleNodes(
- m.standalone, expected_node_occurrence=standalone_convert_count_check
- )
+ self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
res = m(data)
# quantize the reference model
ref_m = prepare_fx(
original_ref_m_copy,
@@ -285,17 +263,13 @@
"output_quantized_idxs": [], # float output
}
interface_config = float_interface_config
# input and output of first conv, observer for standalone module
# will be inserted in the standalone module itself
- prepare_count_check = {
- ns.call_module(torch.ao.quantization.HistogramObserver): 2
- }
+ prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
# for input and output of conv in the standalone module
- standalone_prepare_count_check = {
- ns.call_module(torch.ao.quantization.HistogramObserver): 2
- }
+ standalone_prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
convert_count_check = {
# input and output of reference conv
ns.call_function(torch.quantize_per_tensor): 2,
ns.call_module(nnqr.Conv2d): 1,
ns.call_method("dequantize"): 2,
@@ -351,17 +325,13 @@
"root_module": torch.nn.Conv2d,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
}
custom_backend_config_dict = {"configs": [conv_module_config]}
# observer for input and output of first conv
- prepare_count_check = {
- ns.call_module(torch.ao.quantization.HistogramObserver): 2
- }
+ prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 2}
# for output of conv in the standalone module
- standalone_prepare_count_check = {
- ns.call_module(torch.ao.quantization.HistogramObserver): 1
- }
+ standalone_prepare_count_check = {ns.call_module(torch.ao.quantization.HistogramObserver): 1}
convert_count_check = {
# quantizing input/output for reference conv
ns.call_function(torch.quantize_per_tensor): 2,
ns.call_module(nnqr.Conv2d): 1,
# dequantize the input of reference conv and
@@ -400,13 +370,11 @@
),
weight=torch.ao.quantization.default_weight_observer,
)
self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
- def _test_module(
- self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False
- ):
+ def _test_module(self, m, inputs, shape_ranges, no_prepare=None, no_convert=None, is_qat=False):
"""
Args:
m: the float module we want to test
inputs: list of inputs for the module
shape_ranges: a list of shape_range, where every shape_range is a tuple of
@@ -468,13 +436,11 @@
def forward(self, x):
return self.relu(self.conv(x))
# just testing conv2d since conv1d and conv3d are not supported in fx2trt
- for dim, has_relu, f_relu, is_qat in itertools.product(
- [1, 2], [True, False], [True, False], [True, False]
- ):
+ for dim, has_relu, f_relu, is_qat in itertools.product([1, 2], [True, False], [True, False], [True, False]):
# when has_relu=False, we have torch.nn.Identity, which would introduce
# extra quant-dequat pair
no_convert = {
ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
ns.call_method("dequantize"): 2 + int(not has_relu),
@@ -510,13 +476,11 @@
return self.relu(self.linear(x))
linear_input = torch.rand(8, 5)
shape_ranges = [((1, 5), (5, 5), (10, 5))]
- for has_relu, f_relu, is_qat in itertools.product(
- [True, False], [True, False], [True, False]
- ):
+ for has_relu, f_relu, is_qat in itertools.product([True, False], [True, False], [True, False]):
# when has_relu=False, we have torch.nn.Identity, which would introduce
# extra quant-dequat pair
no_convert = {
ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
ns.call_method("dequantize"): 2 + int(not has_relu),
@@ -662,13 +626,11 @@
ns.call_function(torch.addmm): 1,
ns.call_method("dequantize"): 3,
}
self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
- @unittest.skip(
- "This is not supported yet, we can enable the test after it's supported"
- )
+ @unittest.skip("This is not supported yet, we can enable the test after it's supported")
def test_conv_add(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
@@ -828,13 +790,11 @@
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
standalone_node_occurrence = {
# output of the standalone module
ns.call_module(torch.ao.quantization.HistogramObserver): 1,
}
- self.checkGraphModuleNodes(
- m.standalone, expected_node_occurrence=standalone_node_occurrence
- )
+ self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
m = convert_to_reference_fx(m, backend_config=backend_config_dict)
node_occurrence = {
# two inputs for standalone module
ns.call_function(torch.quantize_per_tensor): 2,
ns.call_module(nn.Conv2d): 1,
@@ -847,13 +807,11 @@
ns.call_module(nn.Conv2d): 1,
ns.call_module(torch.nn.ReLU): 1,
# two input and one output for the pattern in standalone module
ns.call_method("dequantize"): 3,
}
- self.checkGraphModuleNodes(
- m.standalone, expected_node_occurrence=standalone_node_occurrence
- )
+ self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
def test_quant_dequant_not_fold(self):
class LinearModule(torch.nn.Module):
def __init__(self):
super().__init__()
--- py/torch_tensorrt/fx/tools/common_fx2trt.py 2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/tools/common_fx2trt.py 2022-08-12 19:20:01.932333 +0000
@@ -29,13 +29,11 @@
"""
target_atoms = target.split(".")
attr_itr = mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
- raise RuntimeError(
- f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
- )
+ raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available")
@@ -82,13 +80,11 @@
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
outputs = trt_mod(*cuda_inputs)
end_event.record()
torch.cuda.synchronize()
- _LOGGER.info(
- f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}"
- )
+ _LOGGER.info(f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}")
if isinstance(outputs, torch.Tensor):
ref_outputs = [ref_outputs]
outputs = [outputs]
for out, ref in zip(outputs, ref_outputs):
@@ -126,26 +122,22 @@
mod.eval()
if len(expected_ops):
self.assert_has_op(mod, expected_ops)
interpreter_result = interpreter.run(
- lower_precision=LowerPrecision.FP16
- if fp16_mode
- else LowerPrecision.FP32
+ lower_precision=LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
)
trt_mod = TRTModule(
interpreter_result.engine,
interpreter_result.input_names,
interpreter_result.output_names,
)
res_trt = trt_mod(*cuda_inputs).cpu()
res_cpu = mod(*inputs)
assert len(res_trt) == len(res_cpu)
assert len(res_cpu) == len(comparators)
- for output_trt, output_cpu, comparator in zip(
- res_trt, res_cpu, comparators
- ):
+ for output_trt, output_cpu, comparator in zip(res_trt, res_cpu, comparators):
comp_func = comparator[0]
args = comparator[1]
self.assertTrue(comp_func(output_trt, output_cpu, *args))
def run_test_with_error(self, mod, inputs, interpreter, expect_error):
@@ -165,13 +157,11 @@
if node.op == "call_module":
ops_in_mod.add(type(fetch_attr(mod, node.target)))
elif node.op in {"call_function", "call_method"}:
ops_in_mod.add(node.target)
- self.assertTrue(
- ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}"
- )
+ self.assertTrue(ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}")
def assert_unexpected_op(self, mod, ops):
for node in mod.graph.nodes:
if node.op == "call_module":
if type(fetch_attr(mod, node.target)) in ops:
@@ -204,13 +194,11 @@
# after we refactor the internal callsites to use this file
mod = torch.fx.symbolic_trace(mod)
shape_prop.ShapeProp(mod).propagate(*inputs)
mod = NormalizeArgs(mod).transform()
interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
- super().run_test_custom_compare_results(
- mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
- )
+ super().run_test_custom_compare_results(mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode)
class AccTestCase(TRTTestCase):
def run_test(
self,
@@ -233,41 +221,31 @@
pass_tracer = chain_passes(*apply_passes)
mod = pass_tracer(mod, inputs)
if test_implicit_batch_dim:
interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
- super().run_test(
- mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
- )
+ super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)
if test_explicit_batch_dim:
- interp = TRTInterpreter(
- mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
- )
- super().run_test(
- mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
- )
+ interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
+ super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)
if test_explicit_precision:
interp = TRTInterpreter(
mod,
InputTensorSpec.from_tensors(inputs),
explicit_precision=test_explicit_precision,
)
- super().run_test(
- mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol
- )
+ super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)
interp = TRTInterpreter(
mod,
InputTensorSpec.from_tensors(inputs),
explicit_batch_dimension=True,
explicit_precision=test_explicit_precision,
)
- super().run_test(
- mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
- )
+ super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision)
def run_test_with_assert_error(
self,
mod,
inputs,
@@ -281,13 +259,11 @@
if test_implicit_batch_dim:
interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
super().run_test_with_error(mod, inputs, interp, expect_error)
if test_explicit_batch_dim:
- interp = TRTInterpreter(
- mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
- )
+ interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
super().run_test_with_error(mod, inputs, interp, expect_error)
def run_test_with_dynamic_shape(
self,
mod,
--- py/torch_tensorrt/fx/tools/trt_minimizer.py 2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/tools/trt_minimizer.py 2022-08-12 19:20:01.988637 +0000
@@ -8,16 +8,12 @@
from .. import InputTensorSpec, TRTInterpreter, TRTModule
_LOGGER: logging.Logger = logging.getLogger(__name__)
-def lower_mod_default(
- mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048
-) -> TRTModule:
- interp = TRTInterpreter(
- mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
- )
+def lower_mod_default(mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048) -> TRTModule:
+ interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True)
interpreter_result = interp.run(max_batch_size=batch_size)
res_mod = TRTModule(
interpreter_result.engine,
interpreter_result.input_names,
interpreter_result.output_names,
@@ -37,13 +33,11 @@
module: torch.fx.GraphModule,
sample_input: Tensors,
compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]],
settings: TensorRTMinizerSetting = TensorRTMinizerSetting(),
max_batch_size: Any = 2048,
- lower_fn: Callable[
- [torch.fx.GraphModule, Tensors, Any], TRTModule
- ] = lower_mod_default,
+ lower_fn: Callable[[torch.fx.GraphModule, Tensors, Any], TRTModule] = lower_mod_default,
):
self.lower_fn = lower_fn
self.max_batch_size = max_batch_size
super().__init__(module, sample_input, compare_fn, settings)
@@ -56,13 +50,11 @@
mod.eval()
try:
mod = self.lower_fn(mod, inputs, self.max_batch_size)
output = mod(*inputs)
except RuntimeError as e:
- raise net_min_base.FxNetMinimizerRunFuncError(
- f"Encounter an error when processing \n{mod.graph}\n {e}"
- )
+ raise net_min_base.FxNetMinimizerRunFuncError(f"Encounter an error when processing \n{mod.graph}\n {e}")
else:
return output
def get_nodes(self, start=None, end=None, enable_print=False):
nodes = self._collect_nodes(start, end)
--- py/torch_tensorrt/fx/tools/trt_splitter.py 2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/tools/trt_splitter.py 2022-08-12 19:20:02.057670 +0000
@@ -72,13 +72,11 @@
operator_support,
settings,
non_acc_submodule_name="_run_on_gpu_",
)
- def _lower_model_to_backend(
- self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]
- ):
+ def _lower_model_to_backend(self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]):
"""
Lower a GraphModule `mod` to TensorRT with `inputs`.
"""
# Current code for lowering is place-holder, subject to future change
# based on feeds model's actual status
--- py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py 2022-08-12 19:16:11.716868 +0000
+++ py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py 2022-08-12 19:20:02.310172 +0000
@@ -41,13 +41,11 @@
def __init__(self):
super().__init__()
self.exceptions_rewritten: Set[Type[Exception]] = set()
self.exceptions_bool_rewritten: Set[Type[Exception]] = set()
- def rewrite(
- self, fn: FunctionType
- ) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:
+ def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:
# Normalize the source lines
sourcelines, _ = inspect.getsourcelines(fn)
sourcelines = normalize_source_lines(sourcelines)
source = "".join(sourcelines)
@@ -139,12 +137,11 @@
return if_node
# Check that we actually have a builtin exception.
if (
not issubclass(exc_type, Exception)
- or getattr(getattr(exc_type, "__class__", None), "__module__", None)
- != "builtins"
+ or getattr(getattr(exc_type, "__class__", None), "__module__", None) != "builtins"
):
return if_node
# We need a ConditionalExceptionWrapper specialized for every kind of
# exception, so add it to exceptions_rewritten to remember for later to
@@ -156,23 +153,17 @@
# the If with, with args set as the If's condition and the string of the
# exception. The call to the self._conditional_exception_wrapper_*Error
# module is safe because the RewrittenModule will add it as an attr
# based on the returned exceptions_rewritten, and we assume we are
# currently modifying the AST of a method from a RewrittenModule.
- exc_wrapper_node = ast.parse(
- f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval"
- )
+ exc_wrapper_node = ast.parse(f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval")
assert isinstance(exc_wrapper_node, ast.Expression)
exc_wrapper_call_node = exc_wrapper_node.body
assert isinstance(exc_wrapper_call_node, ast.Call)
- if isinstance(if_node.test, ast.BoolOp) and isinstance(
- if_node.test.op, ast.And
- ):
+ if isinstance(if_node.test, ast.BoolOp) and isinstance(if_node.test.op, ast.And):
self.exceptions_bool_rewritten.add(exc_type)
- bool_wrapper_node = ast.parse(
- f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval"
- )
+ bool_wrapper_node = ast.parse(f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval")
assert isinstance(exc_wrapper_node, ast.Expression)
bool_wrapper_call_node = bool_wrapper_node.body
assert isinstance(exc_wrapper_call_node, ast.Call)
bool_wrapper_call_node.args = if_node.test.values
exc_wrapper_call_node.args = [
@@ -323,13 +314,11 @@
name_target[-1] == "_"
and name_target[0] != "_"
and not (name_target in allow_list)
and kind != "placeholder"
):
- raise RuntimeError(
- f"Tried to trace mutable operation {name_target}. FX only supports functional code"
- )
+ raise RuntimeError(f"Tried to trace mutable operation {name_target}. FX only supports functional code")
return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
# List of modules that need rewriting to be supported for tracing.
@@ -384,13 +373,11 @@
# Write all of the non-dunder or special methods from base_class
# into RewrittenModule.
for method_name in dir(base_class):
method = getattr(base_class, method_name, None)
if method is None and method_name not in {"__doc__"}:
- _LOGGER.warning(
- f"{__qualname__} does not have attribute {method_name}"
- )
+ _LOGGER.warning(f"{__qualname__} does not have attribute {method_name}")
if builtins.type(method) is not FunctionType:
continue
# Always skip rewriting dunder methods, as they haven't (yet) been
@@ -437,13 +424,11 @@
# Recursively rewrite and copy all module attrs of this module.
for k, v in orig.__dict__.items():
if k == "_modules":
for mod_k, mod_v in v.items():
if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list: # type: ignore[operator]
- _LOGGER.info(
- f"Skip rewriting leaf module {type(mod_v)}"
- )
+ _LOGGER.info(f"Skip rewriting leaf module {type(mod_v)}")
self._modules[mod_k] = mod_v
else:
self._modules[mod_k] = rewrite_module(mod_v)
else:
self.__dict__[k] = v
@@ -475,25 +460,21 @@
"""
changed = False
for node in reversed(gm.graph.nodes):
if node.op == "call_module" and (
isinstance(gm.get_submodule(node.target), ConditionalExceptionWrapper)
- or isinstance(
- gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper
- )
+ or isinstance(gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper)
):
gm.graph.erase_node(node)
changed = True
return changed
def _replace_tensor_meta_with_rank(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
if node.op != "output" and "tensor_meta" in node.meta:
- node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(
- node.meta["tensor_meta"], lambda x: len(x.shape)
- )
+ node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(node.meta["tensor_meta"], lambda x: len(x.shape))
del node.meta["tensor_meta"]
def rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list):
rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
This is our current configuration: https://github.com/pytorch/TensorRT/blob/master/pyproject.toml. Maybe it's the line width?
Seems like in pytorch/pytorch they don't use the line-length argument: https://github.com/pytorch/pytorch/blob/2c089290b676a221817e48c7de42d1b2bd13609a/pyproject.toml#L22
Code conforms to C++ style guidelines
That is also what I found internally.
So black's default line length is 88. Let's try the default; a quick sketch of the difference is below.
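For readers comparing the two settings, here is a minimal, self-contained Python sketch of the behaviour under discussion. The lower_module helper and its arguments are hypothetical, invented purely to make the snippet run; they echo the keyword-heavy calls in the diffs above but are not the real API. The collapsed lines in those diffs are consistent with a configured width of 120, which is assumed here.
# Hypothetical helper, defined only so the snippet runs; not the real API.
def lower_module(module, input_spec, explicit_batch=False, explicit_precision=False):
    return module, input_spec, explicit_batch, explicit_precision

# At line length 120 the call fits on one line, matching the "+" lines in the diffs:
result = lower_module("resnet", [1, 3, 224, 224], explicit_batch=True, explicit_precision=True)

# At black's default line length of 88 the same call is wrapped, matching the "-" lines:
result = lower_module(
    "resnet", [1, 3, 224, 224], explicit_batch=True, explicit_precision=True
)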
Yeah, that is fine with me. If the FX changes pass both our checks and your internal ones, I can handle reformatting the rest of the Python code.
Alternatively, if you run …
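The comment above is truncated in this capture, so its exact suggestion is unknown. One common workflow it may have pointed toward is running black locally before pushing; a minimal sketch, assuming black is installed, with the fx path taken from the diffs above (the invocation is illustrative, not quoted from the thread):
import subprocess

# Reformat the FX sources in place using black's defaults (line length 88).
subprocess.run(["black", "py/torch_tensorrt/fx/"], check=True)

# Dry run: report which files would change, without modifying anything.
subprocess.run(["black", "--check", "--diff", "py/torch_tensorrt/fx/"], check=True)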
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
@narendasan could you take a look at the pybind issue in the …
@peri044 I think you looked at these classes of errors previously on the nightly channel (it might have been in another FX PR).
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Description
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: