-
Notifications
You must be signed in to change notification settings - Fork 356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
chore: Upgrade TensorRT version to TRT 10 EA #2699
Conversation
chore: updates to trt api chore: trt 10 fixes chore: more fixes
ff95381
to
8980b7a
Compare
author Dheeraj Peri <peri.dheeraj@gmail.com> 1711393059 -0700 committer Dheeraj Peri <peri.dheeraj@gmail.com> 1711393072 -0700 chore: minor updates chore: Fix save failures chore: minor fixes chore: remove duplicate bert test case chore: remove comments chore: add load api chore: minor updates chore: minor updates chore: minor updates chore: more updates
# self.hidden_output_names: Sequence[str] = [] | ||
# for i in range( | ||
# self.engine.num_bindings // self.engine.num_optimization_profiles | ||
# ): | ||
# if i not in primary_input_outputs: | ||
# self.hidden_output_binding_indices_in_order.append(i) | ||
# self.hidden_output_names.append(self.engine.get_binding_name(i)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
review needed
len(self.input_names) | ||
+ len(self.output_names) | ||
+ len(self.hidden_output_names) | ||
# + len(self.hidden_output_names) #TODO: Verify if this is required |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
review needed
# TODO: Verify what this is for ? | ||
# self.hidden_output_dtypes = [ | ||
# unified_dtype_converter( | ||
# self.engine.get_binding_dtype(idx), Frameworks.TORCH | ||
# ) | ||
# for idx in self.hidden_output_binding_indices_in_order | ||
# ] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
review needed
# TODO: Check what is this for ? | ||
# for i, idx in enumerate(self.hidden_output_binding_indices_in_order): | ||
# shape = tuple(self.context.get_binding_shape(idx)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
review needed
@@ -107,7 +107,7 @@ def _from( | |||
return dtype.f16 | |||
elif t == trt.float32: | |||
return dtype.f32 | |||
elif trt.__version__ >= "7.0" and t == trt.bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you removing the version check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This version check actually fails. If trt.__version == 10.0.0b6, then trt.__version__ >= "7.0"
is False (which should actually be true) and hence the trt.bool type wouldn't be returned which results in type errors.
We also don't need these version checks because we only support strict TRT versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks great - added a comment in the TRTEngine
and wanted to note that the Docker build + Docker build jobs will likely fail since the TRT installation will need upgrading for the build, but that can be a separate PR.
Yes, I'm working on docker updates in fp8_trt branch. So that will likely get merged with this PR #2763 |
Co-authored-by: Evan Li <zewenl@nvidia.com>
Co-authored-by: Evan Li <zewenl@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need some patches as comments.
@@ -312,7 +313,7 @@ def run( | |||
) | |||
timing_cache = self._create_timing_cache(builder_config, existing_cache) | |||
|
|||
engine = self.builder.build_engine(self.ctx.net, builder_config) | |||
engine = self.builder.build_serialized_network(self.ctx.net, builder_config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
build_serialized_network
returns a plan
instead of engine
, which is different from previous TRT. In TRT-10, we can do two steps to get an engine
:
plan = builder.build_serialized_network(network, config)
engine = runtime.deserialize_cuda_engine(plan)
If it is on purpose to get plan
, maybe we can change the name to make it easy understand.
assert ( | ||
self.engine.num_io_tensors // self.engine.num_optimization_profiles | ||
) == (len(self.input_names) + len(self.output_names)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assertion should be:
assert self.engine.num_io_tensors == (len(self.input_names) + len(self.output_names))
A bug was reported here: #2811
Description
Upgrade TensorRT version to TRT 10 EA
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: