select_scatter decomp #2515
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
e4c56cd to 037fbcf
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-01-02 18:24:49.853008+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-01-02 18:27:03.483949+00:00
@@ -2,10 +2,11 @@
from torch.testing._internal.common_utils import TestCase, run_tests
import torch_tensorrt
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
+
class TestLowering(TestCase):
def test_lowering_inplace_op(self):
class InPlace(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-01-05 18:29:23.300495+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-01-05 18:31:16.737105+00:00
@@ -481,11 +481,10 @@
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)
-
def test_lowering_select_scatter_dimOne_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -544,7 +543,9 @@
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)
+
+
if __name__ == "__main__":
run_tests()
81a2715 to d5cec9f
d5cec9f to 689105e
# input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
# check if the dim is less than shape
if input_tensor.shape[dim] < index:
    raise AssertionError("The index should not be greater than dim")

# expanding the src_tensor to have the same dimension as input_tensor
# check if the dimension of the src tensor is same as slice tensor
select_tensor = torch.select(input_tensor, dim, index)

if select_tensor.shape != src_tensor.shape:
    raise AssertionError(
        "The slice tensor shape should be equal to the src tensor shape"
    )
Are the AssertionError cases invalid in Torch, or just cases we can't support? Having AssertionErrors in lowering passes can cause models to inexplicably fail for users, so it is not preferable.
If these are invalid cases in Torch itself, then we do not need these assertions. If they are not supported by TRT, then we can instead return the original op and not lower.
Yes, it is an invalid case in Torch itself, so I think the right thing would be to do away with the assertions.
When you say return the original op, would that mean that in those cases we just do:
if condition:
    return <unlowered_original_op>
Ok, thanks for that clarification. I meant that if it would be invalid in Torch, you can assume the condition will never be encountered; otherwise the model would have failed earlier. Specifically, if select_tensor.shape == src_tensor.shape is a requirement of select_scatter, then it is safe to assume the inputs are valid inputs to that function; otherwise we can let Torch throw the error.
Ok got it. Thanks for the clarification!
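The point agreed on above (invalid inputs are rejected by Torch before lowering ever sees them) can be checked in eager mode. This is only a sketch: the `raises` helper and the tensor shapes are made up for illustration, and both `RuntimeError` and `IndexError` are caught because the exact exception type is an assumption that may differ across Torch versions.

```python
import torch

x = torch.zeros(2, 2)


def raises(fn):
    # Hypothetical helper: True if calling fn() makes Torch raise.
    try:
        fn()
    except (RuntimeError, IndexError):
        return True
    return False


# src shape (3,) does not match the selected slice shape (2,)
bad_shape = raises(lambda: torch.select_scatter(x, torch.ones(3), 0, 0))

# index 5 is out of range for a dimension of size 2
bad_index = raises(lambda: torch.select_scatter(x, torch.ones(2), 0, 5))

print(bad_shape, bad_index)
```

If both flags come back true, the two assertions in the snippet under review duplicate checks that Torch already performs, which is the argument for dropping them.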
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py 2024-02-16 00:01:27.167252+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py 2024-02-16 00:03:16.025977+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-16 00:03:16.122237+00:00
@@ -30,16 +30,18 @@
gpu_id (int): Device ID for target GPU
dla_core (int): Core ID for target DLA core
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""
- device_type: Optional[
- trt.DeviceType
- ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ device_type: Optional[trt.DeviceType] = (
+ None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ )
gpu_id: int = -1 #: Device ID for target GPU
dla_core: int = -1 #: Core ID for target DLA core
- allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ allow_gpu_fallback: bool = (
+ False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ )
def __init__(self, *args: Any, **kwargs: Any):
"""__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-16 00:03:16.328118+00:00
@@ -26,16 +26,16 @@
class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1
- shape_mode: Optional[
- _ShapeMode
- ] = None #: Is input statically or dynamically shaped
- shape: Optional[
- Tuple[int, ...] | Dict[str, Tuple[int, ...]]
- ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ shape_mode: Optional[_ShapeMode] = (
+ None #: Is input statically or dynamically shaped
+ )
+ shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+ None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ )
dtype: _enums.dtype = (
_enums.dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
_explicit_set_dtype: bool = False
format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2024-02-16 00:03:16.375728+00:00
@@ -212,13 +212,13 @@
"precision": precision,
"debug": debug,
"device": device,
"workspace_size": workspace_size,
"min_block_size": min_block_size,
- "torch_executed_ops": torch_executed_ops
- if torch_executed_ops is not None
- else set(),
+ "torch_executed_ops": (
+ torch_executed_ops if torch_executed_ops is not None else set()
+ ),
"pass_through_build_failures": pass_through_build_failures,
"max_aux_streams": max_aux_streams,
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-16 00:03:16.569143+00:00
@@ -26,13 +26,13 @@
from packaging import version
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class UnsupportedOperatorException(RuntimeError):
pass
@@ -90,13 +90,13 @@
self.input_specs_iter = 0
self._cur_node_name: Optional[str] = None
self._cur_node: Optional[torch.fx.Node] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
self.compilation_settings = compilation_settings
# Data types for TRT Module output Tensors
self.output_dtypes = output_dtypes
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-16 00:03:16.647465+00:00
@@ -322,17 +322,15 @@
else:
raise AssertionError(f"Cannot convert {input_val} to TRT constant")
@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
- ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...
@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
- ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...
def get_positive_dim(
dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-16 00:03:17.010073+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket
aten = torch.ops.aten
-_core_aten_decompositions: Dict[
- OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+ core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
aten.addcdiv_,
aten.addcmul,
@@ -179,13 +179,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._softmax.default,
}
-ENABLED_TORCH_DECOMPOSITIONS: Dict[
- OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+ get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
def check_decomp_set_invariants() -> None:
"""Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-16 00:03:17.018926+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after lowering linear:\n{gm.graph}")
return gm
-def linear_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def linear_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for linear"""
# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-16 00:03:17.052927+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
return gm
-def view_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
- ]
-):
+def view_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
"""Constructs the original and replacement functions for view"""
# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-16 00:03:17.057189+00:00
@@ -58,16 +58,14 @@
logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
return gm
-def scaled_dot_product_attention_replacement() -> (
- Tuple[
- Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+ Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for efficient attention"""
# Efficient Attention original graph
def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2024-02-16 00:03:17.277499+00:00
@@ -99,25 +99,29 @@
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.output_binding_indices_in_order
]
self.output_shapes = [
- tuple(self.engine.get_binding_shape(idx))
- if self.engine.has_implicit_batch_dimension
- else tuple()
+ (
+ tuple(self.engine.get_binding_shape(idx))
+ if self.engine.has_implicit_batch_dimension
+ else tuple()
+ )
for idx in self.output_binding_indices_in_order
]
self.hidden_output_dtypes = [
unified_dtype_converter(
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.hidden_output_binding_indices_in_order
]
self.hidden_output_shapes = [
- tuple(self.engine.get_binding_shape(idx))
- if self.engine.has_implicit_batch_dimension
- else tuple()
+ (
+ tuple(self.engine.get_binding_shape(idx))
+ if self.engine.has_implicit_batch_dimension
+ else tuple()
+ )
for idx in self.hidden_output_binding_indices_in_order
]
def _check_initialized(self) -> None:
if not self.initialized:
@@ -165,13 +169,15 @@
self.__dict__.update(state)
if self.engine:
self.context = self.engine.create_execution_context()
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
- with torch.autograd.profiler.record_function(
- "PythonTorchTensorRTModule:Forward"
- ) if self.profiling_enabled else nullcontext():
+ with (
+ torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+ if self.profiling_enabled
+ else nullcontext()
+ ):
self._check_initialized()
# If in safe mode, check at each iteration for for whether a switch is required
if (
torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
torch.cuda.set_device(device_id)
inputs = tuple([tensor.to(device) for tensor in inputs])
logger.warning(f"Moved all input Tensors to cuda:{device_id}")
- with torch.autograd.profiler.record_function(
- "PythonTorchTensorRTModule:ProcessInputs"
- ) if self.profiling_enabled else nullcontext():
+ with (
+ torch.autograd.profiler.record_function(
+ "PythonTorchTensorRTModule:ProcessInputs"
+ )
+ if self.profiling_enabled
+ else nullcontext()
+ ):
assert len(inputs) == len(
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@
self.context.set_binding_shape(
idx, tuple(contiguous_inputs[i].shape)
)
- with torch.autograd.profiler.record_function(
- "PythonTorchTensorRTModule:ProcessOutputs"
- ) if self.profiling_enabled else nullcontext():
+ with (
+ torch.autograd.profiler.record_function(
+ "PythonTorchTensorRTModule:ProcessOutputs"
+ )
+ if self.profiling_enabled
+ else nullcontext()
+ ):
# create output tensors
outputs: List[torch.Tensor] = []
for i, idx in enumerate(self.output_binding_indices_in_order):
shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
dtype=self.hidden_output_dtypes[i],
device=torch.cuda.current_device(),
)
bindings[idx] = output.data_ptr()
- with torch.autograd.profiler.record_function(
- "PythonTorchTensorRTModule:TensorRTRuntime"
- ) if self.profiling_enabled else nullcontext():
+ with (
+ torch.autograd.profiler.record_function(
+ "PythonTorchTensorRTModule:TensorRTRuntime"
+ )
+ if self.profiling_enabled
+ else nullcontext()
+ ):
self.context.execute_async_v2(
bindings, torch.cuda.current_stream().cuda_stream
)
if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-16 00:03:17.622507+00:00
@@ -315,25 +315,21 @@
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"input": args[0],
"kernel_size": args[1],
- "stride": args[2]
- if len(args) > 2
- else (None, None)
- if len(args[1]) == 2
- else (None, None, None),
- "padding": args[3]
- if len(args) > 3
- else (0, 0)
- if len(args[1]) == 2
- else (0, 0, 0),
- "dilation": args[4]
- if len(args) > 4
- else (1, 1)
- if len(args[1]) == 2
- else (1, 1, 1),
+ "stride": (
+ args[2]
+ if len(args) > 2
+ else (None, None) if len(args[1]) == 2 else (None, None, None)
+ ),
+ "padding": (
+ args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+ ),
+ "dilation": (
+ args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+ ),
"ceil_mode": args[5] if len(args) > 5 else False,
}
return acc_ops_converters.acc_ops_max_poolnd(
network, target, None, kwargs_new, name
)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-16 00:03:17.675873+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
@@ -73,13 +73,13 @@
self.input_specs_iter = 0
self.validate_input_specs()
self._cur_node_name: Optional[str] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
def validate_input_specs(self):
for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py 2024-02-16 00:03:17.684340+00:00
@@ -124,25 +124,29 @@
interpreter = TRTInterpreter(
mod,
input_specs=self.lower_setting.input_specs,
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
explicit_precision=self.lower_setting.explicit_precision,
- logger_level=trt.Logger.VERBOSE
- if self.lower_setting.verbose_log
- else trt.Logger.WARNING,
+ logger_level=(
+ trt.Logger.VERBOSE
+ if self.lower_setting.verbose_log
+ else trt.Logger.WARNING
+ ),
)
interp_result: TRTInterpreterResult = interpreter.run(
max_batch_size=self.lower_setting.max_batch_size,
max_workspace_size=self.lower_setting.max_workspace_size,
lower_precision=self.lower_setting.lower_precision,
strict_type_constraints=self.lower_setting.strict_type_constraints,
algorithm_selector=algo_selector,
timing_cache=cache_data,
- profiling_verbosity=trt.ProfilingVerbosity.DETAILED
- if self.lower_setting.verbose_profile
- else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+ profiling_verbosity=(
+ trt.ProfilingVerbosity.DETAILED
+ if self.lower_setting.verbose_profile
+ else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+ ),
tactic_sources=self.lower_setting.tactic_sources,
)
# Update timing cache file if needed
timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
module.half()
# A custom conversion function can be passed to the lowerer to
# handle inputs with custom types. By default, just handle
# tensors and NoneType.
if fp16_conversion_fn is None:
- conversion_fn = (
- lambda x: x.half()
- if x is not None and x.dtype == torch.float32
- else x
+ conversion_fn = lambda x: (
+ x.half() if x is not None and x.dtype == torch.float32 else x
)
else:
conversion_fn = fp16_conversion_fn
inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2024-02-16 00:03:17.896029+00:00
@@ -194,13 +194,15 @@
lowering_start_time = datetime.datetime.now()
self.lower_setting.input_specs = generate_input_specs(
submod_inputs,
self.lower_setting,
- additional_submodule_inputs[submod_name]
- if additional_submodule_inputs
- else None,
+ (
+ additional_submodule_inputs[submod_name]
+ if additional_submodule_inputs
+ else None
+ ),
)
lowered_module = self._lower_func(
submod, submod_inputs, self.lower_setting, submod_name
)
setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
_LOGGER.info(f"ACC submodule graph: {submod.graph}")
lowering_start_time = datetime.datetime.now()
self.lower_setting.additional_inputs = (
- additional_submodule_inputs[submod_name]
- if additional_submodule_inputs
- else None,
+ (
+ additional_submodule_inputs[submod_name]
+ if additional_submodule_inputs
+ else None
+ ),
)
lowered_module = self._lower_func(
submod, submod_inputs, self.lower_setting, submod_name
)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-16 00:03:18.123580+00:00
@@ -193,13 +193,11 @@
kwargs2 = {"equal_nan": True}
if rtol:
kwargs2["rtol"] = rtol
if atol:
kwargs2["atol"] = atol
- kwargs2[
- "msg"
- ] = (
+ kwargs2["msg"] = (
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
)
# If tensors are on different devices, make sure to compare
# their copies that are on the same device.
if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-16 00:03:18.166433+00:00
@@ -536,13 +536,13 @@
reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
maybe_reshape
)
if not reshape_batch_size:
continue
- reshape_batch_size_inferred_source: Optional[
- fx.Node
- ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+ reshape_batch_size_inferred_source: Optional[fx.Node] = (
+ get_reshape_batch_size_inferred_source(reshape_batch_size)
+ )
if not reshape_batch_size_inferred_source:
continue
reshape_input: fx.Node = maybe_reshape.kwargs["input"]
if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py 2024-02-16 00:01:27.187252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py 2024-02-16 00:03:18.592970+00:00
@@ -21,13 +21,15 @@
inputs = [torch.randn(1, 10)]
self.run_test(
Split(),
inputs,
expected_ops={
- acc_ops.split
- if isinstance(split_size_or_sections, int)
- else acc_ops.slice_tensor
+ (
+ acc_ops.split
+ if isinstance(split_size_or_sections, int)
+ else acc_ops.slice_tensor
+ )
},
test_explicit_batch_dim=False,
)
@parameterized.expand(
@@ -68,13 +70,15 @@
]
self.run_test_with_dynamic_shape(
Split(),
input_specs,
expected_ops={
- acc_ops.split
- if isinstance(split_size_or_sections, int)
- else acc_ops.slice_tensor
+ (
+ acc_ops.split
+ if isinstance(split_size_or_sections, int)
+ else acc_ops.slice_tensor
+ )
},
)
# Testing with (-1, -1, -1) results into following error:
# AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py 2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py 2024-02-16 00:03:19.259154+00:00
@@ -152,13 +152,13 @@
mod.eval()
if len(expected_ops):
self.assert_has_op(mod, expected_ops)
interpreter_result = interpreter.run(
- lower_precision=LowerPrecision.FP16
- if fp16_mode
- else LowerPrecision.FP32
+ lower_precision=(
+ LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+ )
)
trt_mod = TRTModule(
interpreter_result.engine,
interpreter_result.input_names,
interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py 2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py 2024-02-16 00:03:19.609670+00:00
@@ -67,25 +67,29 @@
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.output_binding_indices_in_order
]
self.output_shapes = [
- tuple(self.engine.get_binding_shape(idx))
- if self.engine.has_implicit_batch_dimension
- else tuple()
+ (
+ tuple(self.engine.get_binding_shape(idx))
+ if self.engine.has_implicit_batch_dimension
+ else tuple()
+ )
for idx in self.output_binding_indices_in_order
]
self.hidden_output_dtypes: Sequence[torch.dtype] = [
unified_dtype_converter(
self.engine.get_binding_dtype(idx), Frameworks.TORCH
)
for idx in self.hidden_output_binding_indices_in_order
]
self.hidden_output_shapes = [
- tuple(self.engine.get_binding_shape(idx))
- if self.engine.has_implicit_batch_dimension
- else tuple()
+ (
+ tuple(self.engine.get_binding_shape(idx))
+ if self.engine.has_implicit_batch_dimension
+ else tuple()
+ )
for idx in self.hidden_output_binding_indices_in_order
]
def _check_initialized(self):
if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py 2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py 2024-02-16 00:03:19.911816+00:00
@@ -404,13 +404,13 @@
"inputs": inputs if inputs is not None else [],
# "input_signature": input_signature,
"device": device,
"disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers.
- "enabled_precisions": enabled_precisions
- if enabled_precisions is not None
- else set(), # Enabling FP16 kernels
+ "enabled_precisions": (
+ enabled_precisions if enabled_precisions is not None else set()
+ ), # Enabling FP16 kernels
"refit": refit, # enable refit
"debug": debug, # enable debuggable engine
"capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels
"num_avg_timing_iters": num_avg_timing_iters, # Number of averaging timing iterations used to select kernels
"workspace_size": workspace_size, # Maximum size of workspace given to TensorRT
150f055 to f1ff596
2eaae77 to b5b45a1
Overall, looks good - just had one question, added below
unbind_tensors = torch.unbind(input_tensor, dim)
unbind_tensors_list = list(unbind_tensors)
unbind_tensors_list[index] = src_tensor
return torch.stack(tuple(unbind_tensors_list), dim)
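To make the snippet above concrete, here is a minimal sketch of the same unbind, replace, stack idea using plain nested Python lists in place of tensors, so it runs without torch. The function name `select_scatter_2d` and the restriction to 2-D inputs are assumptions made for illustration; they are not part of the PR.

```python
def select_scatter_2d(rows, src, dim, index):
    """Model of select_scatter for a 2-D nested list (hypothetical helper).

    Mirrors the unbind -> replace one piece -> stack sequence, with
    lists standing in for tensors.
    """
    # "unbind": split the input into pieces along `dim`.
    if dim == 0:
        pieces = [list(r) for r in rows]          # one piece per row
    elif dim == 1:
        pieces = [list(c) for c in zip(*rows)]    # one piece per column
    else:
        raise ValueError("this sketch only handles dim 0 or 1")

    if len(src) != len(pieces[0]):
        raise ValueError("src must match the shape of one selected piece")

    # Replace the selected piece; negative indices work as in Python lists.
    pieces[index] = list(src)

    # "stack": reassemble the pieces along the same dim.
    if dim == 0:
        return pieces
    return [list(r) for r in zip(*pieces)]


x = [[0, 0], [0, 0]]
src = [1, 1]
print(select_scatter_2d(x, src, 0, 0))
print(select_scatter_2d(x, src, 1, 0))
```

The input is copied rather than modified, which matches the out-of-place behaviour expected of a scatter-style op.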
What operators does this generate in the graph after tracing? Is there a before/after sample that could be shared?
@gs-olive, these were the graphs:
Pre-AOT Autograd graph:
graph():
%l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
%l_src_ : torch.Tensor [num_users=1] = placeholder[target=L_src_]
%select_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%l_x_, %l_src_, 0, 0), kwargs = {})
return (select_scatter_default,)
Post-AOT Autograd graph:
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg1_1,), kwargs = {})
%clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone_1, 0, 1, 2), kwargs = {})
%squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%slice_2, 0), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%clone, %squeeze_1],), kwargs = {})
%view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%cat, [2, 2]), kwargs = {})
return (view,)
Graph after constant folding:
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%arg0_1, 0, 1, 2), kwargs = {})
%squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%slice_2, 0), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%arg1_1, %squeeze_1],), kwargs = {})
%view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%cat, [2, 2]), kwargs = {})
return (view,)
Post-lowering passes Autograd graph:
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%arg0_1, 0, 1, 2), kwargs = {})
%squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%slice_2, 0), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%arg1_1, %squeeze_1],), kwargs = {})
%reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [2, 2]), kwargs = {})
return (reshape_default,)
However, I have since changed the implementation to use slice_scatter,
and I have updated the description accordingly.
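Read step by step, the post-lowering graph above slices off the part of the input that is kept, squeezes the sliced dimension, concatenates `src` in front of it, and reshapes the result. A rough model with nested lists follows, assuming an input of shape [2, 2] and a `src` of shape [2] (the shapes are inferred from the `[2, 2]` reshape target and the `(0, 1, 2)` slice arguments; they are not stated elsewhere in the thread). The function name is hypothetical.

```python
def traced_select_scatter_dim0_index0(x, src):
    """List-based model of the traced graph for select_scatter(x, src, 0, 0).

    Assumes x has shape [2, 2] and src has shape [2].
    """
    # aten.slice.Tensor(x, 0, 1, 2): keep row 1 only, shape [1, 2]
    sliced = x[1:2]
    # aten.squeeze.dim(sliced, 0): drop the leading size-1 dim, shape [2]
    squeezed = sliced[0]
    # aten.cat([src, squeezed]): 1-D concatenation, shape [4]
    flat = list(src) + list(squeezed)
    # aten.reshape(flat, [2, 2]): back to the original shape
    return [flat[0:2], flat[2:4]]


x = [[1, 2], [3, 4]]
src = [9, 8]
print(traced_select_scatter_dim0_index0(x, src))
```

Note that this particular sequence is specialised to the traced shapes and to index 0 along dim 0; other indices or dims would presumably trace to a different arrangement of slices.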
I think the old implementation was valid, but the new one does not seem to work in some cases, for example:
>>> import torch
>>> a = torch.zeros(2, 2)
>>> b = torch.ones(2)
>>> torch.select_scatter(a, b, 0, 0)
tensor([[1., 1.],
[0., 0.]])
>>> torch.slice_scatter(a, b.unsqueeze(0), 0, 1, 1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: expected src to have a size equal to the slice of self. src size = [1, 2], slice size = [0, 2]
See this decomposition for an alternative approach.
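The error above appears to come from the slice bounds: `start=1, end=1` selects an empty slice, whereas writing at a single index `i` needs the one-element slice from `i` to `i + 1`. Below is a minimal list-based sketch of that relationship along dim 0. The helper names are hypothetical, and this is neither the decomposition linked above nor necessarily the one used in the PR.

```python
def slice_scatter_rows(rows, src_rows, start, end, step=1):
    """Model of slice_scatter along dim 0 for nested lists (hypothetical)."""
    # Resolve the slice to concrete row positions, as Python slicing would.
    target = range(*slice(start, end, step).indices(len(rows)))
    if len(src_rows) != len(target):
        raise ValueError(
            f"src has {len(src_rows)} rows but the slice selects {len(target)}"
        )
    out = [list(r) for r in rows]  # copy, so the input is left untouched
    for pos, src in zip(target, src_rows):
        out[pos] = list(src)
    return out


def select_scatter_rows(rows, src_row, index):
    """select_scatter along dim 0, expressed through slice_scatter_rows."""
    if index < 0:
        # Normalise first: otherwise index=-1 would give the empty slice [-1:0].
        index += len(rows)
    # Unsqueeze src to one row and write it into the slice [index, index + 1).
    return slice_scatter_rows(rows, [src_row], index, index + 1)


a = [[0.0, 0.0], [0.0, 0.0]]
b = [1.0, 1.0]
print(select_scatter_rows(a, b, 0))
```

With torch, the analogous call for the example above should be `torch.slice_scatter(a, b.unsqueeze(0), 0, 0, 1)`, that is, `end = start + 1` rather than `end = start`.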
Thanks @gs-olive for pointing out the above.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-03-26 20:46:17.748006+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-03-26 20:53:01.432160+00:00
@@ -607,7 +607,8 @@
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)
+
if __name__ == "__main__":
run_tests()
So, in this case would the implementation not be functional without the Additionally, if the
I misread the case you pointed out.
In the above according to the implementation above, the
To answer the above question-
Looks good to me, pending rebase + CI passing
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-05-30 17:34:01.396563+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py 2024-05-30 17:35:56.410082+00:00
@@ -669,10 +669,10 @@
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
- )
+ )
if __name__ == "__main__":
run_tests()
Changing lowering of select_scatter select_scatter changes select_scatter changes Test case for select_scatter removing assertion adding select_scatter decomp lowering ops in test implement select_scatter using slice_scatter adding test case linting commit fix
Fixes #2436
This PR depends on #2519, #2664 and #2669. Major changes:
2519- Decomposition of aten::slice_scatter
2664- Implementation makes use of aten::scatter.src
2669- Constants are converted to fake tensors in the get_attr call, which leads to mismatched device locations (meta and cpu) in torch