Implemented basic pipeline for Refitting #2886
Conversation
def refit_trt_engine_from_module(
    exported_program: ExportedProgram,
Remove the settings that don't do anything for refit.
def refit_trt_engine_from_module(
    exported_program: ExportedProgram,
    inputs: Tuple[Any, ...],
    engine: object,
Eventually this will become the compiled exported program.
@@ -609,3 +610,126 @@ def convert_module_to_trt_engine(
    engine_bytearray = engine_bytes.getvalue()

    return engine_bytearray


def refit_trt_engine_from_module(
Eventually, something like:

def refit_module_weights(
    compiled_module: ExportedProgram,
    new_weight_module: ExportedProgram,
) -> torch.fx.GraphModule:
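As a sketch of how that API might eventually be used (hypothetical: MyModel and new_state_dict are placeholders, and refit_module_weights does not exist in this PR yet):

import torch
import torch_tensorrt

inputs = (torch.randn(1, 3, 224, 224).cuda(),)

model = MyModel().eval().cuda()  # placeholder model
exp_program = torch.export.export(model, inputs)
compiled_module = torch_tensorrt.dynamo.compile(exp_program, inputs=list(inputs))

# Later, after the weights change (new checkpoint, fine-tuning, ...):
model.load_state_dict(new_state_dict)  # placeholder state dict
new_weight_module = torch.export.export(model, inputs)

# Proposed entry point: swap the weights inside the existing TRT engines
# without re-running the full conversion.
refitted_gm = refit_module_weights(compiled_module, new_weight_module)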
enabled_precisions = {dtype._from(e) for e in enabled_precisions}

compilation_options = {
We can store the compilation settings as metadata in the returned graph; then we can just read the compiled program to fill these settings in and match the original lowering.
Ask Dheeraj where to put the metadata for lowering.
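A minimal sketch of that metadata idea, assuming the settings land in the returned GraphModule's meta dict (the key name "_torch_tensorrt_settings" is an assumption, pending the question above):

import dataclasses

import torch

def attach_compile_settings(gm: torch.fx.GraphModule, settings) -> None:
    # Stash the CompilationSettings used at compile time so refit can
    # reproduce the original lowering without re-asking the user.
    gm.meta["_torch_tensorrt_settings"] = dataclasses.asdict(settings)

def read_compile_settings(gm: torch.fx.GraphModule) -> dict:
    return gm.meta["_torch_tensorrt_settings"]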
mapping = get_refit_mapping(gm, input_list, settings)

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
Let's move this stuff into a submodule or other file (torch_tensorrt/dynamo/_refit.py).
Reuse the global logger:

TensorRT/py/torch_tensorrt/logging.py, line 33 in 3422c41:
TRT_LOGGER = _TRTLogger()
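i.e., roughly (assuming the module-level logger shown in the permalink is importable as-is):

import tensorrt as trt
from torch_tensorrt.logging import TRT_LOGGER  # reuse instead of trt.Logger(trt.Logger.ERROR)

refitter = trt.Refitter(engine, TRT_LOGGER)  # engine: the built ICudaEngine to refit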
@@ -88,6 +89,61 @@ def interpret_module_to_result(
    return interpreter_result


def get_refit_mapping(
Move this to the refit file
def construct_refit_weight_mapping(
    new_weights_mod: torch.fx.GraphModule,
    compile_settings: CompilationSettings,  # info from the metadata of the compiled module
):
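A rough sketch of a body under that signature; TRTInterpreter, input_specs, and the (name, role) -> weights mapping shape are assumptions drawn from this PR's conversion code, not a settled design:

import tensorrt as trt
import torch

def construct_refit_weight_mapping(
    new_weights_mod: torch.fx.GraphModule,
    compile_settings,  # CompilationSettings read from the compiled module's metadata
    input_specs,       # assumed: the same input specs used at compile time
) -> dict:
    # Re-interpret the new-weights graph into a TRT network definition
    # (without building an engine), using the original compile settings.
    interpreter = TRTInterpreter(
        new_weights_mod, input_specs, compilation_settings=compile_settings
    )
    interpreter._construct_trt_network_def()  # name proposed later in this review
    net = interpreter.ctx.net

    # Pair each refittable layer with its fresh weights.
    mapping = {}
    for i in range(net.num_layers):
        layer = net.get_layer(i)
        if layer.type == trt.LayerType.CONVOLUTION:
            layer.__class__ = trt.IConvolutionLayer  # expose kernel/bias accessors
            mapping[(layer.name, trt.WeightsRole.KERNEL)] = layer.kernel
            mapping[(layer.name, trt.WeightsRole.BIAS)] = layer.bias
        # ... other weighted layer types (deconvolution, constant, ...) would
        # be handled similarly
    return mapping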
    serialized_engine, self._input_names, self._output_names, serialized_cache
)

def get_network_to_refit(
Let's call this something like:

def _construct_trt_network_def()
The user would do something like:

interpreter._construct_trt_network_def()
net = interpreter.ctx.net
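Combined with TensorRT's refit API, applying a (name, role) -> weights mapping like the one sketched earlier might then look like this (a sketch; engine is the already-built ICudaEngine from the compiled module):

import tensorrt as trt

interpreter._construct_trt_network_def()  # build the network definition only
net = interpreter.ctx.net                 # e.g. to harvest fresh weights from it

# Push the harvested weights into the existing engine in place.
refitter = trt.Refitter(engine, TRT_LOGGER)
for (layer_name, role), weights in mapping.items():
    refitter.set_weights(layer_name, role, weights)

# refit_cuda_engine() returns False if any refittable weights were not supplied.
assert refitter.refit_cuda_engine()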
AI: replace all uses of the locally created TRT_LOGGER with the global one. cc: @peri044
…efit. Support setting loading
py/torch_tensorrt/dynamo/_refit.py (outdated)
compiled_module = copy.copy(compiled_module)
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
for name, _ in partitioned_module.named_children():
We need to ensure that the new module's partition is the same as the compiled module's partition.
- We can check the number of subgraphs, and perhaps the names as well (if deterministic).
- At compile time, compute the hash of the source fx graph (https://github.com/pytorch/pytorch/blob/fba21edf5b9aa14babb9c0bc860dc9c597eb8010/torch/_inductor/codecache.py#L670) and store it as an attribute on the TRTModule. Then compare the hash of the new graph to the one stored in the compiled subgraph module (a minimal sketch follows this list).
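A minimal sketch of the hash option, assuming the graph's printed form is stable enough to key on (the linked _inductor codecache uses a richer pickle-based key, and the attribute name here is made up):

import hashlib

import torch

def fx_graph_hash(gm: torch.fx.GraphModule) -> str:
    # Hash a canonical text rendering of the graph topology.
    return hashlib.sha256(str(gm.graph).encode()).hexdigest()

# At compile time, stash the hash on each TRT submodule:
trt_module.source_fx_graph_hash = fx_graph_hash(subgraph)  # attribute name assumed

# At refit time, verify the new partition matches before touching any weights:
if fx_graph_hash(new_subgraph) != trt_module.source_fx_graph_hash:
    raise RuntimeError("New module's partition does not match the compiled module's")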
@zewenli98 You might be interested in reusing this part for engine caching
LGTM.