
Overhaul upsample dynamo converter #2790

Merged: 1 commit merged into pytorch:main from fix_upsample on Jul 31, 2024

Conversation
Conversation

@HolyWu (Contributor) commented on Apr 27, 2024

Description

  • Add missing converters for the nearest1d, nearest3d, linear1d, trilinear3d, and bicubic2d operators.
  • Fix incorrect handling of the align_corners argument (see the coordinate-mapping sketch after this list).
  • Remove upsample_bilinear2d from torch_enabled_decompositions.
  • Add support for dynamic shapes.
  • Override PyTorch's CompositeImplicitAutograd dispatch key so that users of torch.nn.functional.interpolate actually hit these converters rather than decomposed operators.
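For background on the align_corners fix, here is a minimal sketch (generic formulas, not code from this PR) of the two standard coordinate transforms the flag selects between when mapping an output pixel back to an input coordinate:

def src_coord(dst: int, in_size: int, out_size: int, align_corners: bool) -> float:
    """Map an output pixel index to a fractional input coordinate."""
    if align_corners:
        # Corner pixels of input and output are aligned exactly.
        return dst * (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
    # Half-pixel mapping used when align_corners=False.
    return (dst + 0.5) * in_size / out_size - 0.5


# Example: width 9 upsampled to 22 (the case that appears in the log later in this thread).
print(src_coord(0, 9, 22, align_corners=False))  # -0.2954...
print(src_coord(0, 9, 22, align_corners=True))   # 0.0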

Fixes #2680

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 27, 2024
@github-actions github-actions bot requested a review from apbose April 27, 2024 10:23
@HolyWu HolyWu force-pushed the fix_upsample branch 2 times, most recently from 91ff218 to 5b8a24d on April 30, 2024
@HolyWu HolyWu force-pushed the fix_upsample branch 2 times, most recently from 6a511f6 to 0d3f9a0 on June 8, 2024
@HolyWu (Contributor, Author) commented on Jun 8, 2024

Hi @apbose. I have rewritten the converter registrations. The .default and .vec overloads are now registered separately, since their schemas differ, and test cases for .default have been added.

Note that the optional scales* arguments in .default are intentionally not used in our converter: the output_size argument is not optional, and PyTorch has already computed the final output size, which the converter can use directly, as can be seen in the following test script and log.
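As an aside before the script: the final output size is obtained by flooring input_size * scale_factor per spatial dimension (an assumption that the [9, 22] in the log below bears out). A quick illustration:

import math

input_size = (6, 9)          # spatial (H, W) of the input in the script below
scale_factor = (1.5, 2.5)

# floor(6 * 1.5) = 9, floor(9 * 2.5) = floor(22.5) = 22
output_size = [math.floor(s * f) for s, f in zip(input_size, scale_factor)]
print(output_size)  # [9, 22], matching the output_size passed to upsample_bicubic2d in the log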

import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Bicubic upsample with non-integer scale factors; for a (1, 3, 6, 9) input the output is (1, 3, 9, 22).
        return torch.nn.functional.interpolate(x, scale_factor=(1.5, 2.5), mode="bicubic")


device = torch.device("cuda", 0)
model = MyModule().eval().to(device)
inputs = [torch.rand((1, 3, 6, 9), dtype=torch.float, device=device)]

with torch.inference_mode():
    optimized_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=inputs,
        enabled_precisions={torch.float},
        debug=True,
        min_block_size=1,
        device=device,
    )

    optimized_model(*inputs)
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)

DEBUG:torch_tensorrt.dynamo.backend.backends:Pre-AOT Autograd graph:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %interpolate : [num_users=1] = call_function[target=torch.nn.functional.interpolate](args = (%l_x_,), kwargs = {scale_factor: (1.5, 2.5), mode: bicubic})
    return (interpolate,)
DEBUG:torch_tensorrt.dynamo.lowering._repair_input_aliasing:Inserted auxiliary clone nodes for placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %interpolate : [num_users=1] = call_function[target=torch.nn.functional.interpolate](args = (%clone_default,), kwargs = {scale_factor: (1.5, 2.5), mode: bicubic})
    return (interpolate,)
DEBUG:torch_tensorrt.dynamo.lowering._remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %interpolate : [num_users=1] = call_function[target=torch.nn.functional.interpolate](args = (%clone_default,), kwargs = {scale_factor: (1.5, 2.5), mode: bicubic})
    return (interpolate,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %interpolate : [num_users=1] = call_function[target=torch.nn.functional.interpolate](args = (%clone_default,), kwargs = {scale_factor: (1.5, 2.5), mode: bicubic})
    return (interpolate,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Post-AOT Autograd graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %upsample_bicubic2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bicubic2d.default](args = (%clone, [9, 22], False, 1.5, 2.5), kwargs = {})
    return (upsample_bicubic2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone from graph, since it is a clone node which is the only user of placeholder arg0_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bicubic2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bicubic2d.default](args = (%arg0_1, [9, 22], False, 1.5, 2.5), kwargs = {})
    return (upsample_bicubic2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bicubic2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bicubic2d.default](args = (%arg0_1, [9, 22], False, 1.5, 2.5), kwargs = {})
    return (upsample_bicubic2d,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bicubic2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bicubic2d.default](args = (%arg0_1, [9, 22], False, 1.5, 2.5), kwargs = {})
    return (upsample_bicubic2d,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.upsample_bicubic2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.upsample_bicubic2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
 Input shapes: [(1, 3, 6, 9)]
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bicubic2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bicubic2d.default](args = (%arg0_1, [9, 22], False, 1.5, 2.5), kwargs = {})
    return upsample_bicubic2d
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 12954, GPU 1009 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2637, GPU +308, now: CPU 15854, GPU 1317 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg0_1 [shape=[1, 3, 6, 9], dtype=DataType.FLOAT]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node upsample_bicubic2d (kind: aten.upsample_bicubic2d.default, args: ('arg0_1 <tensorrt.ITensor [shape=(1, 3, 6, 9), dtype=DataType.FLOAT]>', [9, 22], False, 1.5, 2.5))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 3, 9, 22), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.000977
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 0.0770113 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 1 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3522 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 1 timing cache entries
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.079129
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 10732 bytes of Memory
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 6, 9)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 6, 9)@float32]
     Number of Operators in Engine: 1
     Engine Outputs: Tensor: (1, 3, 9, 22)@float32
    ...
   Outputs: Tuple(Tensor: (1, 3, 9, 22)@float32)

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.

@apbose (Collaborator) commented on Jun 13, 2024

Thanks for the PR and the detailed explanations, @HolyWu. One question about the signature above:
%upsample_bicubic2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bicubic2d.default](args = (%arg0_1, [9, 22], False, 1.5, 2.5), kwargs = {})
As you said, the output shape is always given, but isn't the scale_factor [1.5, 2.5]? What do the 1.5 and 2.5 denote here?

@HolyWu (Contributor, Author) commented on Jun 14, 2024

In the schema, they are scales_h and scales_w arguments, respectively.
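For reference, both overload schemas can be printed directly; the descriptions in the comments reflect the schemas of the PyTorch version used here and may vary slightly across releases:

import torch

# .default takes a mandatory output_size plus optional per-dimension scales
# (scales_h, scales_w); .vec takes an optional output_size plus an optional
# scale_factors list.
print(torch.ops.aten.upsample_bicubic2d.default._schema)
print(torch.ops.aten.upsample_bicubic2d.vec._schema)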

@apbose (Collaborator) commented on Jul 8, 2024

LGTM! @HolyWu could you rebase the PR? Pending on CI.

@apbose apbose self-requested a review July 8, 2024 22:24
@HolyWu (Contributor, Author) commented on Jul 9, 2024

While I'm at it, I'll also add the supports_dynamic_shape argument to the decorator and corresponding tests for dynamic shapes.
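For context, a rough sketch of exercising a dynamic-shape compile through the public torch_tensorrt.Input API (illustrative only, not the test code added in this PR; the shape ranges are made up):

import torch
import torch_tensorrt


class Upsample(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.interpolate(x, scale_factor=(1.5, 2.5), mode="bicubic")


# Dynamic height/width expressed via min/opt/max shapes (hypothetical ranges).
dyn_input = torch_tensorrt.Input(
    min_shape=(1, 3, 4, 4),
    opt_shape=(1, 3, 6, 9),
    max_shape=(1, 3, 16, 16),
    dtype=torch.float,
)

model = Upsample().eval().cuda()
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[dyn_input])
print(trt_model(torch.rand(1, 3, 6, 9, device="cuda")).shape)  # torch.Size([1, 3, 9, 22])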

@HolyWu HolyWu marked this pull request as draft July 14, 2024 12:04
@HolyWu HolyWu changed the title Fix incomplete upsample dynamo converter Overhaul upsample dynamo converter Jul 25, 2024
@HolyWu HolyWu marked this pull request as ready for review July 25, 2024 12:39
@HolyWu (Contributor, Author) commented on Jul 25, 2024

Dynamic-shape support has been added. I also managed to override PyTorch's CompositeImplicitAutograd dispatch key, so users of torch.nn.functional.interpolate can actually use these converters rather than getting decomposed operators.
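For anyone wanting to check the effect, here is a small sketch (standard torch.export usage, not code from this PR) that lists which aten ops F.interpolate lowers to in a traced graph; with the override in place, the upsample op should reach the Torch-TensorRT converters instead of being decomposed first:

import torch
import torch.nn.functional as F


class M(torch.nn.Module):
    def forward(self, x):
        return F.interpolate(x, scale_factor=2.0, mode="bicubic")


ep = torch.export.export(M(), (torch.rand(1, 3, 8, 8),))
# Print the aten ops in the traced graph to see whether the upsample op survives
# as a single node (e.g. aten.upsample_bicubic2d) or appears already decomposed.
print([node.target for node in ep.graph.nodes if node.op == "call_function"])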

@HolyWu HolyWu requested a review from apbose July 25, 2024 12:59
@keehyuna (Collaborator) commented

LGTM. Tests were successful on my local build; they run on the TRT engine as expected.

@peri044 peri044 merged commit bc00de6 into pytorch:main Jul 31, 2024
4 checks passed
@HolyWu HolyWu deleted the fix_upsample branch July 31, 2024 23:37
Successfully merging this pull request may close these issues.

🐛 [Bug] Incomplete upsample implementation in dynamo converter