
Commit 2518db5
fix: distinguish engines based on compilation settings in addition to graph structure

Signed-off-by: Naren Dasan <naren@narendasan.com>
narendasan committed Sep 11, 2024
1 parent f84be56 commit 2518db5
Showing 2 changed files with 39 additions and 1 deletion.
28 changes: 28 additions & 0 deletions py/torch_tensorrt/_Input.py
@@ -220,6 +220,34 @@ def __str__(self) -> str:
     def __repr__(self) -> str:
         return self.__str__()
 
+    @staticmethod
+    def equivalent_spec(a: Input, b: Input) -> bool:
+        if a.shape_mode != b.shape_mode:
+            return False
+
+        if a.shape_mode == Input._ShapeMode.DYNAMIC:
+            assert isinstance(a.shape, dict)
+            assert isinstance(b.shape, dict)
+            checks = [
+                a.shape["min_shape"] == b.shape["min_shape"],
+                a.shape["opt_shape"] == b.shape["opt_shape"],
+                a.shape["max_shape"] == b.shape["max_shape"],
+                a.dtype == b.dtype,
+                a.format == b.format,
+                a.low_tensor_domain_incl == b.low_tensor_domain_incl,
+                a.high_tensor_domain_excl == b.high_tensor_domain_excl,
+            ]
+            return all(checks)
+        else:
+            checks = [
+                a.shape == b.shape,
+                a.dtype == b.dtype,
+                a.format == b.format,
+                a.low_tensor_domain_incl == b.low_tensor_domain_incl,
+                a.high_tensor_domain_excl == b.high_tensor_domain_excl,
+            ]
+            return all(checks)
+
     @staticmethod
     def _supported_input_size_type(input_size: Any) -> bool:
         if isinstance(input_size, torch.Size):
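
A minimal sketch of how the new helper might be exercised, assuming the public torch_tensorrt.Input constructor (the min_shape/opt_shape/max_shape, shape, and dtype keywords are part of the existing API; the concrete shapes here are made up for illustration):

    import torch
    from torch_tensorrt import Input

    # Two dynamic-shape specs covering the same shape range with the same settings.
    a = Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
        dtype=torch.half,
    )
    b = Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
        dtype=torch.half,
    )
    assert Input.equivalent_spec(a, b)  # identical specs -> a cached engine is reusable

    # A static-shape spec has a different shape_mode, so it never matches a dynamic one.
    c = Input(shape=(8, 3, 224, 224), dtype=torch.half)
    assert not Input.equivalent_spec(a, c)
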
12 changes: 11 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -545,7 +545,7 @@ def run(
             serialized_engine,
             self._input_names,
             self._output_names,
-            engine_input_specs,
+            cached_engine_input_specs,
             engine_compilation_settings,
             self.weight_name_map,
         ) = cached_data
@@ -559,6 +559,16 @@
             setting_compatiblity
         ), f"Attempted to refit a prebuilt engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"
 
+        for i, is_equivalent in enumerate(
+            [
+                Input.equivalent_spec(cached, new)
+                for cached, new in zip(cached_engine_input_specs, self.input_specs)
+            ]
+        ):
+            assert (
+                is_equivalent
+            ), f"Found that the cached engine was built for a different input spec (index: {i}, cached spec: {cached_engine_input_specs[i]}, new spec: {self.input_specs[i]})"
+
         _LOGGER.info(
             "Found the cached engine that corresponds to this graph. It is directly loaded."
         )
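
The effect of the second hunk, restated as a standalone predicate for clarity (this helper is hypothetical and only distills the validation the diff performs inline): a cached engine is reused only when the compilation settings are compatible and every cached input spec is equivalent to the corresponding new one.

    from typing import Sequence
    from torch_tensorrt import Input

    def can_reuse_cached_engine(
        cached_specs: Sequence[Input],
        new_specs: Sequence[Input],
        settings_compatible: bool,
    ) -> bool:
        # Hypothetical distillation of the checks added in this commit.
        if not settings_compatible or len(cached_specs) != len(new_specs):
            return False
        return all(
            Input.equivalent_spec(cached, new)
            for cached, new in zip(cached_specs, new_specs)
        )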
