From 2518db5d547a5533bba56a582aae526e788a06b3 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Wed, 11 Sep 2024 08:07:19 -0600
Subject: [PATCH] fix: distinguish engines based on compilation settings in
 addition to graph structure

Signed-off-by: Naren Dasan
---
 py/torch_tensorrt/_Input.py                    | 28 +++++++++++++++++++
 .../dynamo/conversion/_TRTInterpreter.py       | 12 +++++++-
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 72775944cb..126219ee8a 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -220,6 +220,34 @@ def __str__(self) -> str:
     def __repr__(self) -> str:
         return self.__str__()
 
+    @staticmethod
+    def equivalent_spec(a: Input, b: Input) -> bool:
+        if a.shape_mode != b.shape_mode:
+            return False
+
+        if a.shape_mode == Input._ShapeMode.DYNAMIC:
+            assert isinstance(a.shape, dict)
+            assert isinstance(b.shape, dict)
+            checks = [
+                a.shape["min_shape"] == b.shape["min_shape"],
+                a.shape["opt_shape"] == b.shape["opt_shape"],
+                a.shape["max_shape"] == b.shape["max_shape"],
+                a.dtype == b.dtype,
+                a.format == b.format,
+                a.low_tensor_domain_incl == b.low_tensor_domain_incl,
+                a.high_tensor_domain_excl == b.high_tensor_domain_excl,
+            ]
+            return all(checks)
+        else:
+            checks = [
+                a.shape == b.shape,
+                a.dtype == b.dtype,
+                a.format == b.format,
+                a.low_tensor_domain_incl == b.low_tensor_domain_incl,
+                a.high_tensor_domain_excl == b.high_tensor_domain_excl,
+            ]
+            return all(checks)
+
     @staticmethod
     def _supported_input_size_type(input_size: Any) -> bool:
         if isinstance(input_size, torch.Size):
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 48306522b1..f1b68b5436 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -545,7 +545,7 @@ def run(
                     serialized_engine,
                     self._input_names,
                     self._output_names,
-                    engine_input_specs,
+                    cached_engine_input_specs,
                     engine_compilation_settings,
                     self.weight_name_map,
                 ) = cached_data
@@ -559,6 +559,16 @@
                     setting_compatiblity
                 ), f"Attempted to refit a prebuilt engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"
 
+                for i, is_equivalent in enumerate(
+                    [
+                        Input.equivalent_spec(cached, new)
+                        for cached, new in zip(cached_engine_input_specs, self.input_specs)
+                    ]
+                ):
+                    assert (
+                        is_equivalent
+                    ), f"Found that the cached engine was built for a different input spec (input index: {i}, cached spec: {cached_engine_input_specs[i]}, new spec: {self.input_specs[i]})"
+
                 _LOGGER.info(
                     "Found the cached engine that corresponds to this graph. It is directly loaded."
                 )
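
Illustration (not part of the patch): a minimal sketch of how the new
Input.equivalent_spec check is intended to behave, assuming the public
torch_tensorrt.Input constructor with its default format and tensor-domain
values; the shapes and dtypes below are hypothetical examples.

    import torch
    from torch_tensorrt import Input

    # Identical static specs: an engine cached for `a` can be reused
    # when a later compilation sees `b`.
    a = Input(shape=(1, 3, 224, 224), dtype=torch.float32)
    b = Input(shape=(1, 3, 224, 224), dtype=torch.float32)
    assert Input.equivalent_spec(a, b)

    # Same graph structure and shape, but a different dtype: the specs
    # no longer compare as equivalent, so with this patch the cached
    # engine is not silently reused for the mismatched inputs.
    c = Input(shape=(1, 3, 224, 224), dtype=torch.float16)
    assert not Input.equivalent_spec(a, c)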