Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: InstanceNorm decomposition #3288

Merged
merged 2 commits into from
Nov 15, 2024

Conversation

HolyWu
Copy link
Contributor

@HolyWu HolyWu commented Nov 11, 2024

Description

InstanceNorm decomposition by PyTorch leads to inferior performance, and it also causes graph breaks and engine building failure with dynamic shapes in TorchTRT.

Fixes #3265

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Nov 11, 2024
@github-actions github-actions bot requested a review from gs-olive November 11, 2024 09:49
@HolyWu
Copy link
Contributor Author

HolyWu commented Nov 11, 2024

Performance

from __future__ import annotations

import os

import numpy as np
import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

times = 10


@torch.inference_mode()
def benchmark(model: torch.nn.Module, inputs: list[torch.Tensor]) -> np.ndarray:
    # Warm up
    for i in range(3):
        model(inputs[i])

    torch.cuda.synchronize()

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]

    for i in range(times):
        torch.cuda._sleep(1_000_000)

        start_events[i].record()
        model(inputs[i])
        end_events[i].record()

    torch.cuda.synchronize()

    timings = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return np.array(timings)


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.m = torch.nn.InstanceNorm2d(64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.m(x)


torch.manual_seed(12345)
model = MyModule().eval().cuda()
inputs = [torch_tensorrt.Input((32, 64, 128, 256), dtype=torch.float)]

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float},
    debug=True,
    min_block_size=1,
)

inputs = [torch.randn((32, 64, 128, 256), device="cuda") for _ in range(times)]
timing = benchmark(trt_model, inputs)

print("")
print("Timing:")
print(f"Min={timing.min()} ms, Mean={timing.mean()} ms, Max={timing.max()} ms")
print("")

with torch.inference_mode():
    for i in range(times):
        torch.testing.assert_close(model(inputs[i]), trt_model(inputs[i]), rtol=5e-3, atol=5e-3)
print("assert_close passed")

torch._dynamo.reset()

Before Patch

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %instance_norm : [num_users=1] = call_function[target=torch.ops.aten.instance_norm.default](args = (%x, None, None, None, None, True, 0.1, 1e-05, True), kwargs = {})
    return (instance_norm,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %view : [num_users=4] = call_function[target=torch.ops.aten.view.default](args = (%x, [1, 2048, 128, 256]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%view, [0, 2, 3], True), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %mean), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 32768.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, 2048, 1, 1], [1]), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.prims.sum.default](args = (%view, [0, 2, 3]), kwargs = {})
    %broadcast_in_dim_1 : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%sum_2, [1, 2048, 1, 1], [1]), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%broadcast_in_dim_1, 32768.0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %div_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %div_2), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_1, [32, 64, 128, 256]), kwargs = {})
    return (view_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %view : [num_users=4] = call_function[target=torch.ops.aten.view.default](args = (%x, [1, 2048, 128, 256]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%view, [0, 2, 3], True), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %mean), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 32768.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, 2048, 1, 1], [1]), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.prims.sum.default](args = (%view, [0, 2, 3]), kwargs = {})
    %broadcast_in_dim_1 : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%sum_2, [1, 2048, 1, 1], [1]), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%broadcast_in_dim_1, 32768.0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %div_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %div_2), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_1, [32, 64, 128, 256]), kwargs = {})
    return (view_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.fuse_prims_broadcast:Graph after fusing prims-broadcast paradigm:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %view : [num_users=4] = call_function[target=torch.ops.aten.view.default](args = (%x, [1, 2048, 128, 256]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%view, [0, 2, 3], True), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %mean), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 32768.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, 2048, 1, 1], [1]), kwargs = {})
    %sum_dim_int_list : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%view, [0, 2, 3], True), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%sum_dim_int_list, 32768.0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %div_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %div_2), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_1, [32, 64, 128, 256]), kwargs = {})
    return (view_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.view_to_reshape:Graph after replacing view with reshape:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %reshape_default : [num_users=4] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [1, 2048, 128, 256]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %mean), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 32768.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, 2048, 1, 1], [1]), kwargs = {})
    %sum_dim_int_list : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%sum_dim_int_list, 32768.0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %div_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %div_2), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mul_1, [32, 64, 128, 256]), kwargs = {})
    return (reshape_default_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %reshape_default : [num_users=4] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [1, 2048, 128, 256]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %mean), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 32768.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, 2048, 1, 1], [1]), kwargs = {})
    %sum_dim_int_list : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%sum_dim_int_list, 32768.0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %div_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %div_2), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mul_1, [32, 64, 128, 256]), kwargs = {})
    return (reshape_default_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %reshape_default : [num_users=4] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [1, 2048, 128, 256]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %mean), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 32768.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, 2048, 1, 1], [1]), kwargs = {})
    %sum_dim_int_list : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%sum_dim_int_list, 32768.0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %div_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %div_2), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mul_1, [32, 64, 128, 256]), kwargs = {})
    return (reshape_default_1,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.reshape.default + Operator Count: 2
- torch.ops.aten.mean.dim + Operator Count: 1
- torch.ops.aten.sub.Tensor + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 2
- torch.ops.aten.sum.dim_IntList + Operator Count: 2
- torch.ops.aten.div.Tensor + Operator Count: 2
- torch.ops.prims.broadcast_in_dim.default + Operator Count: 1
- torch.ops.prims.div.default + Operator Count: 1
- torch.ops.aten.add.Tensor + Operator Count: 1
- torch.ops.aten.sqrt.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 15 operators out of 15 in subgraph.
WARNING:torch_tensorrt.dynamo._compiler:Node sum_dim_int_list of op type call_function does not have metadata. This could sometimes lead to undefined behavior.
WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.reshape.default + Operator Count: 2
- torch.ops.aten.mean.dim + Operator Count: 1
- torch.ops.aten.sub.Tensor + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 2
- torch.ops.aten.sum.dim_IntList + Operator Count: 2
- torch.ops.aten.div.Tensor + Operator Count: 2
- torch.ops.prims.broadcast_in_dim.default + Operator Count: 1
- torch.ops.prims.div.default + Operator Count: 1
- torch.ops.aten.add.Tensor + Operator Count: 1
- torch.ops.aten.sqrt.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(32, 64, 128, 256)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %reshape_default : [num_users=4] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [1, 2048, 128, 256]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %mean), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 32768.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, 2048, 1, 1], [1]), kwargs = {})
    %sum_dim_int_list : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%sum_dim_int_list, 32768.0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %div_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %div_2), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mul_1, [32, 64, 128, 256]), kwargs = {})
    return reshape_default_1
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[32, 64, 128, 256], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (32, 64, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/reshape_default (kind: aten.reshape.default, args: ('x <Node>', ['1 <int>', '2048 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/reshape_default [aten.reshape.default] (Inputs: (x: (32, 64, 128, 256)@torch.float32, [1, 2048, 128, 256]) | Outputs: (reshape_default: (1, 2048, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/mean (kind: aten.mean.dim, args: ('reshape_default <Node>', ['0 <int>', '2 <int>', '3 <int>'], 'True <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/mean [aten.mean.dim] (Inputs: (reshape_default: (1, 2048, 128, 256)@torch.float32, [0, 2, 3], True) | Outputs: (mean: (1, 2048, 1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sub (kind: aten.sub.Tensor, args: ('reshape_default <Node>', 'mean <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sub [aten.sub.Tensor] (Inputs: (reshape_default: (1, 2048, 128, 256)@torch.float32, mean: (1, 2048, 1, 1)@torch.float32) | Outputs: (sub: (1, 2048, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/mul (kind: aten.mul.Tensor, args: ('sub <Node>', 'sub <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/mul [aten.mul.Tensor] (Inputs: (sub: (1, 2048, 128, 256)@torch.float32, sub: (1, 2048, 128, 256)@torch.float32) | Outputs: (mul: (1, 2048, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sum_1 (kind: aten.sum.dim_IntList, args: ('mul <Node>', ['0 <int>', '2 <int>', '3 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sum_1 [aten.sum.dim_IntList] (Inputs: (mul: (1, 2048, 128, 256)@torch.float32, [0, 2, 3]) | Outputs: (sum_1: (2048,)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/div (kind: aten.div.Tensor, args: ('sum_1 <Node>', '32768.0 <float>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/div [aten.div.Tensor] (Inputs: (sum_1: (2048,)@torch.float32, 32768.0) | Outputs: (div: (2048,)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/broadcast_in_dim (kind: prims.broadcast_in_dim.default, args: ('div <Node>', ['1 <int>', '2048 <int>', '1 <int>', '1 <int>'], ['1 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/broadcast_in_dim [prims.broadcast_in_dim.default] (Inputs: (div: (2048,)@torch.float32, [1, 2048, 1, 1], [1]) | Outputs: (broadcast_in_dim: (1, 2048, 1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node sum_dim_int_list (kind: aten.sum.dim_IntList, args: ('reshape_default <Node>', ['0 <int>', '2 <int>', '3 <int>'], 'True <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node sum_dim_int_list [aten.sum.dim_IntList] (Inputs: (reshape_default: (1, 2048, 128, 256)@torch.float32, [0, 2, 3], True) | Outputs: (sum_dim_int_list: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/div_1 (kind: prims.div.default, args: ('sum_dim_int_list <Node>', '32768.0 <float>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/div_1 [prims.div.default] (Inputs: (sum_dim_int_list: , 32768.0) | Outputs: (div_1: (1, 2048, 1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/add (kind: aten.add.Tensor, args: ('broadcast_in_dim <Node>', '1e-05 <float>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/add [aten.add.Tensor] (Inputs: (broadcast_in_dim: (1, 2048, 1, 1)@torch.float32, 1e-05) | Outputs: (add: (1, 2048, 1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sqrt (kind: aten.sqrt.default, args: ('add <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sqrt [aten.sqrt.default] (Inputs: (add: (1, 2048, 1, 1)@torch.float32) | Outputs: (sqrt: (1, 2048, 1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/div_2 (kind: aten.div.Tensor, args: ('1 <int>', 'sqrt <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/div_2 [aten.div.Tensor] (Inputs: (1, sqrt: (1, 2048, 1, 1)@torch.float32) | Outputs: (div_2: (1, 2048, 1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sub_1 (kind: aten.sub.Tensor, args: ('reshape_default <Node>', 'div_1 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sub_1 [aten.sub.Tensor] (Inputs: (reshape_default: (1, 2048, 128, 256)@torch.float32, div_1: (1, 2048, 1, 1)@torch.float32) | Outputs: (sub_1: (1, 2048, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/mul_1 (kind: aten.mul.Tensor, args: ('sub_1 <Node>', 'div_2 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/mul_1 [aten.mul.Tensor] (Inputs: (sub_1: (1, 2048, 128, 256)@torch.float32, div_2: (1, 2048, 1, 1)@torch.float32) | Outputs: (mul_1: (1, 2048, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/reshape_default_1 (kind: aten.reshape.default, args: ('mul_1 <Node>', ['32 <int>', '64 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/reshape_default_1 [aten.reshape.default] (Inputs: (mul_1: (1, 2048, 128, 256)@torch.float32, [32, 64, 128, 256]) | Outputs: (reshape_default_1: (32, 64, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('reshape_default_1 <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(32, 64, 128, 256), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (reshape_default_1: (32, 64, 128, 256)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.018041
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:26.714361
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 67492 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 1047 microseconds.
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 1456
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 268451840
DEBUG: [Torch-TensorRT] - - Runner scratch: 268451840 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +256, now: CPU 0, GPU 256 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: x
      shape: [32, 64, 128, 256]
      dtype: Float
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [32, 64, 128, 256]
      dtype: Float
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 15 Total Operators, of which 15 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False, enable_cross_compile_for_windows=False)

  Graph Structure:

   Inputs: List[Tensor: (32, 64, 128, 256)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (32, 64, 128, 256)@float32]
     Number of Operators in Engine: 15
     Engine Outputs: List[Tensor: (32, 64, 128, 256)@float32]
    ...
   Outputs: List[Tensor: (32, 64, 128, 256)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 15.0
   Most Operators in a TRT Engine: 15

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=15 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=15 which generates 1 TRT engine(s)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [32, 64, 128, 256]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [32, 64, 128, 256]

Timing:
Min=13.204704284667969 ms, Mean=13.591363143920898 ms, Max=14.2673921585083 ms

assert_close passed

After Patch (need #3273 to be merged first)

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %instance_norm : [num_users=1] = call_function[target=torch.ops.aten.instance_norm.default](args = (%x, None, None, None, None, True, 0.1, 1e-05, True), kwargs = {})
    return (instance_norm,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, 32, 64, 32768, 64, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, 32, 64, 32768, 64, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, 32, 64, 32768, 64, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, 32, 64, 32768, 64, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.native_group_norm.default + Operator Count: 1
- _operator.getitem + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 2 operators out of 2 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.native_group_norm.default + Operator Count: 1
- _operator.getitem + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(32, 64, 128, 256)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, 32, 64, 32768, 64, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return getitem
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[32, 64, 128, 256], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (32, 64, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/native_group_norm (kind: aten.native_group_norm.default, args: ('x <Node>', 'None <NoneType>', 'None <NoneType>', '32 <int>', '64 <int>', '32768 <int>', '64 <int>', '1e-05 <float>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/native_group_norm [aten.native_group_norm.default] (Inputs: (x: (32, 64, 128, 256)@torch.float32, None, None, 32, 64, 32768, 64, 1e-05) | Outputs: (native_group_norm: ((32, 64, 128, 256)@torch.float32, (32, 64)@torch.float32, (32, 64)@torch.float32)))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/getitem (kind: <built-in function getitem>, args: ('native_group_norm <Node>', '0 <int>'))
DEBUG:torch_tensorrt.dynamo.conversion.ops_evaluators:Evaluating _operator.getitem on object with name: m/getitem
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/getitem [<built-in function getitem>] (Inputs: (native_group_norm: ((32, 64, 128, 256)@torch.float32, (32, 64)@torch.float32, (32, 64)@torch.float32), 0) | Outputs: (getitem: (32, 64, 128, 256)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('getitem <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(32, 64, 128, 256), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (getitem: (32, 64, 128, 256)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.003032
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.278343
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 56588 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 103 microseconds.
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 512
DEBUG: [Torch-TensorRT] - - Runner scratch: 512 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: x
      shape: [32, 64, 128, 256]
      dtype: Float
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [32, 64, 128, 256]
      dtype: Float
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 2 Total Operators, of which 2 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False, enable_cross_compile_for_windows=False)

  Graph Structure:

   Inputs: List[Tensor: (32, 64, 128, 256)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (32, 64, 128, 256)@float32]
     Number of Operators in Engine: 2
     Engine Outputs: List[Tensor: (32, 64, 128, 256)@float32]
    ...
   Outputs: List[Tensor: (32, 64, 128, 256)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 2.0
   Most Operators in a TRT Engine: 2

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=2 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=2 which generates 1 TRT engine(s)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [32, 64, 128, 256]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [32, 64, 128, 256]

Timing:
Min=2.1780478954315186 ms, Mean=2.2668895959854125 ms, Max=2.6511359214782715 ms

assert_close passed

@HolyWu
Copy link
Contributor Author

HolyWu commented Nov 11, 2024

Dynamic Shapes

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.m = torch.nn.InstanceNorm2d(4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.m(x)


with torch.inference_mode():
    model = MyModule().eval().cuda()

    inputs = [
        torch_tensorrt.Input(min_shape=(2, 4, 2, 2), opt_shape=(2, 4, 6, 8), max_shape=(8, 4, 8, 8), dtype=torch.float)
    ]

    trt_model = torch_tensorrt.compile(
        model,
        "dynamo",
        inputs,
        enabled_precisions={torch.float},
        debug=True,
        min_block_size=1,
    )

    inputs = [torch.randn((3, 4, 5, 6), device="cuda")]

    torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-03, atol=5e-03)
    print("assert_close passed")

Before Patch

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %instance_norm : [num_users=1] = call_function[target=torch.ops.aten.instance_norm.default](args = (%x, None, None, None, None, True, 0.1, 1e-05, True), kwargs = {})
    return (instance_norm,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul : [num_users=3] = call_function[target=operator.mul](args = (%sym_size_int_4, 4), kwargs = {})
    %view : [num_users=4] = call_function[target=torch.ops.aten.view.default](args = (%x, [1, %mul, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%view, [0, 2, 3], True), kwargs = {})
    %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %mean), kwargs = {})
    %mul_14 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_3, %sub_3), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_14, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 48.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, %mul, 1, 1], [1]), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.prims.sum.default](args = (%view, [0, 2, 3]), kwargs = {})
    %broadcast_in_dim_1 : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%sum_2, [1, %mul, 1, 1], [1]), kwargs = {})
    %mul_15 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %sym_float : [num_users=1] = call_function[target=torch.sym_float](args = (%mul_15,), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%broadcast_in_dim_1, %sym_float), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add_4,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %div_1), kwargs = {})
    %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_4, %div_2), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_16, [%sym_size_int_4, 4, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    return (view_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul : [num_users=3] = call_function[target=operator.mul](args = (%sym_size_int_4, 4), kwargs = {})
    %view : [num_users=4] = call_function[target=torch.ops.aten.view.default](args = (%x, [1, %mul, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%view, [0, 2, 3], True), kwargs = {})
    %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %mean), kwargs = {})
    %mul_14 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_3, %sub_3), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_14, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 48.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, %mul, 1, 1], [1]), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.prims.sum.default](args = (%view, [0, 2, 3]), kwargs = {})
    %broadcast_in_dim_1 : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%sum_2, [1, %mul, 1, 1], [1]), kwargs = {})
    %mul_15 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %sym_float : [num_users=1] = call_function[target=torch.sym_float](args = (%mul_15,), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%broadcast_in_dim_1, %sym_float), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add_4,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%view, %div_1), kwargs = {})
    %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_4, %div_2), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_16, [%sym_size_int_4, 4, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    return (view_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.view_to_reshape:Graph after replacing view with reshape:
graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul : [num_users=3] = call_function[target=operator.mul](args = (%sym_size_int_4, 4), kwargs = {})
    %reshape_default : [num_users=4] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [1, %mul, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %mean), kwargs = {})
    %mul_14 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_3, %sub_3), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_14, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 48.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, %mul, 1, 1], [1]), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.prims.sum.default](args = (%reshape_default, [0, 2, 3]), kwargs = {})
    %broadcast_in_dim_1 : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%sum_2, [1, %mul, 1, 1], [1]), kwargs = {})
    %mul_15 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %sym_float : [num_users=1] = call_function[target=torch.sym_float](args = (%mul_15,), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%broadcast_in_dim_1, %sym_float), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add_4,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %div_1), kwargs = {})
    %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_4, %div_2), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mul_16, [%sym_size_int_4, 4, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    return (reshape_default_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul : [num_users=3] = call_function[target=operator.mul](args = (%sym_size_int_4, 4), kwargs = {})
    %reshape_default : [num_users=4] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [1, %mul, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %mean), kwargs = {})
    %mul_14 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_3, %sub_3), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_14, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 48.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, %mul, 1, 1], [1]), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.prims.sum.default](args = (%reshape_default, [0, 2, 3]), kwargs = {})
    %broadcast_in_dim_1 : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%sum_2, [1, %mul, 1, 1], [1]), kwargs = {})
    %mul_15 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %sym_float : [num_users=1] = call_function[target=torch.sym_float](args = (%mul_15,), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%broadcast_in_dim_1, %sym_float), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add_4,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %div_1), kwargs = {})
    %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_4, %div_2), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mul_16, [%sym_size_int_4, 4, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    return (reshape_default_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul : [num_users=3] = call_function[target=operator.mul](args = (%sym_size_int_4, 4), kwargs = {})
    %reshape_default : [num_users=4] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [1, %mul, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %mean), kwargs = {})
    %mul_14 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_3, %sub_3), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_14, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 48.0), kwargs = {})
    %broadcast_in_dim : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%div, [1, %mul, 1, 1], [1]), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.prims.sum.default](args = (%reshape_default, [0, 2, 3]), kwargs = {})
    %broadcast_in_dim_1 : [num_users=1] = call_function[target=torch.ops.prims.broadcast_in_dim.default](args = (%sum_2, [1, %mul, 1, 1], [1]), kwargs = {})
    %mul_15 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %sym_float : [num_users=1] = call_function[target=torch.sym_float](args = (%mul_15,), kwargs = {})
    %div_1 : [num_users=1] = call_function[target=torch.ops.prims.div.default](args = (%broadcast_in_dim_1, %sym_float), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%broadcast_in_dim, 1e-05), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%add_4,), kwargs = {})
    %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (1, %sqrt), kwargs = {})
    %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %div_1), kwargs = {})
    %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_4, %div_2), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mul_16, [%sym_size_int_4, 4, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    return (reshape_default_1,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.sym_size.int + Operator Count: 3
- _operator.mul + Operator Count: 2
- torch.ops.aten.reshape.default + Operator Count: 2
- torch.ops.aten.mean.dim + Operator Count: 1
- torch.ops.aten.sub.Tensor + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 2
- torch.ops.aten.sum.dim_IntList + Operator Count: 1
- torch.ops.aten.div.Tensor + Operator Count: 2
- torch.ops.prims.sum.default + Operator Count: 1
- torch.ops.prims.div.default + Operator Count: 1
- torch.ops.aten.add.Tensor + Operator Count: 1
- torch.ops.aten.sqrt.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.prims.broadcast_in_dim.default + Operator Count: 2
- torch.sym_float + Operator Count: 1

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 19 operators out of 22 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 2
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.sym_size.int + Operator Count: 3
- _operator.mul + Operator Count: 2
- torch.ops.aten.reshape.default + Operator Count: 2
- torch.ops.aten.mean.dim + Operator Count: 1
- torch.ops.aten.sub.Tensor + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 2
- torch.ops.aten.sum.dim_IntList + Operator Count: 1
- torch.ops.aten.div.Tensor + Operator Count: 2
- torch.ops.prims.sum.default + Operator Count: 1
- torch.ops.prims.div.default + Operator Count: 1
- torch.ops.aten.add.Tensor + Operator Count: 1
- torch.ops.aten.sqrt.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.prims.broadcast_in_dim.default + Operator Count: 2
- torch.sym_float + Operator Count: 1

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [{'min_shape': (1, 4, 1, 1), 'opt_shape': (2, 4, 6, 8), 'max_shape': (8, 4, 8, 8)}]
 graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul : [num_users=2] = call_function[target=operator.mul](args = (%sym_size_int_4, 4), kwargs = {})
    %reshape_default : [num_users=4] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [1, %mul, %sym_size_int_5, %sym_size_int_6]), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%reshape_default, [0, 2, 3], True), kwargs = {})
    %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%reshape_default, %mean), kwargs = {})
    %mul_14 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_3, %sub_3), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_14, [0, 2, 3]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, 48.0), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.prims.sum.default](args = (%reshape_default, [0, 2, 3]), kwargs = {})
    %mul_15 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    return (div, mul, sum_2, mul_15, reshape_default, sym_size_int_4, sym_size_int_5, sym_size_int_6)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[-1, 4, -1, -1], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (s0, 4, s1, s2)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _empty_nn_module_stack_from_metadata_hook/sym_size_int_4 (kind: aten.sym_size.int, args: ('x <Node>', '0 <int>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _empty_nn_module_stack_from_metadata_hook/sym_size_int_4 [aten.sym_size.int] (Inputs: (x: (s0, 4, s1, s2)@torch.float32, 0) | Outputs: (sym_size_int_4: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _empty_nn_module_stack_from_metadata_hook/sym_size_int_5 (kind: aten.sym_size.int, args: ('x <Node>', '2 <int>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _empty_nn_module_stack_from_metadata_hook/sym_size_int_5 [aten.sym_size.int] (Inputs: (x: (s0, 4, s1, s2)@torch.float32, 2) | Outputs: (sym_size_int_5: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _empty_nn_module_stack_from_metadata_hook/sym_size_int_6 (kind: aten.sym_size.int, args: ('x <Node>', '3 <int>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _empty_nn_module_stack_from_metadata_hook/sym_size_int_6 [aten.sym_size.int] (Inputs: (x: (s0, 4, s1, s2)@torch.float32, 3) | Outputs: (sym_size_int_6: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/mul (kind: <built-in function mul>, args: ('sym_size_int_4 <Node>', '4 <int>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/mul [<built-in function mul>] (Inputs: (sym_size_int_4: , 4) | Outputs: (mul: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/reshape_default (kind: aten.reshape.default, args: ('x <Node>', ['1 <int>', 'mul <Node>', 'sym_size_int_5 <Node>', 'sym_size_int_6 <Node>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/reshape_default [aten.reshape.default] (Inputs: (x: (s0, 4, s1, s2)@torch.float32, [1, mul, sym_size_int_5, sym_size_int_6]) | Outputs: (reshape_default: (1, 4*s0, s1, s2)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/mean (kind: aten.mean.dim, args: ('reshape_default <Node>', ['0 <int>', '2 <int>', '3 <int>'], 'True <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/mean [aten.mean.dim] (Inputs: (reshape_default: (1, 4*s0, s1, s2)@torch.float32, [0, 2, 3], True) | Outputs: (mean: (1, 4*s0, 1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sub_3 (kind: aten.sub.Tensor, args: ('reshape_default <Node>', 'mean <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sub_3 [aten.sub.Tensor] (Inputs: (reshape_default: (1, 4*s0, s1, s2)@torch.float32, mean: (1, 4*s0, 1, 1)@torch.float32) | Outputs: (sub_3: (1, 4*s0, s1, s2)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/mul_14 (kind: aten.mul.Tensor, args: ('sub_3 <Node>', 'sub_3 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/mul_14 [aten.mul.Tensor] (Inputs: (sub_3: (1, 4*s0, s1, s2)@torch.float32, sub_3: (1, 4*s0, s1, s2)@torch.float32) | Outputs: (mul_14: (1, 4*s0, s1, s2)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sum_1 (kind: aten.sum.dim_IntList, args: ('mul_14 <Node>', ['0 <int>', '2 <int>', '3 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sum_1 [aten.sum.dim_IntList] (Inputs: (mul_14: (1, 4*s0, s1, s2)@torch.float32, [0, 2, 3]) | Outputs: (sum_1: (4*s0,)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/div (kind: aten.div.Tensor, args: ('sum_1 <Node>', '48.0 <float>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/div [aten.div.Tensor] (Inputs: (sum_1: (4*s0,)@torch.float32, 48.0) | Outputs: (div: (4*s0,)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sum_2 (kind: prims.sum.default, args: ('reshape_default <Node>', ['0 <int>', '2 <int>', '3 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sum_2 [prims.sum.default] (Inputs: (reshape_default: (1, 4*s0, s1, s2)@torch.float32, [0, 2, 3]) | Outputs: (sum_2: (4*s0,)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/mul_15 (kind: <built-in function mul>, args: ('sym_size_int_5 <Node>', 'sym_size_int_6 <Node>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/mul_15 [<built-in function mul>] (Inputs: (sym_size_int_5: , sym_size_int_6: ) | Outputs: (mul_15: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: (('div <Node>', 'mul <Node>', 'sum_2 <Node>', 'mul_15 <Node>', 'reshape_default <Node>', 'sym_size_int_4 <Node>', 'sym_size_int_5 <Node>', 'sym_size_int_6 <Node>'),))
Traceback (most recent call last):
  File "C:\Users\HolyWu\Downloads\test.py", line 25, in <module>
    trt_model = torch_tensorrt.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\_compile.py", line 286, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\_compiler.py", line 608, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\_compiler.py", line 810, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_conversion.py", line 90, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_conversion.py", line 69, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 626, in run
    self._construct_trt_network_def()
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 357, in _construct_trt_network_def
    super().run()
  File "C:\Python312\Lib\site-packages\torch\fx\interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 692, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\fx\interpreter.py", line 228, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 855, in output
    raise RuntimeError(
RuntimeError: Specified output dtypes (3) differ from number of outputs (8)

While executing return (div, mul, sum_2, mul_15, reshape_default, sym_size_int_4, sym_size_int_5, sym_size_int_6)
Original traceback:
None

After Patch

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %instance_norm : [num_users=1] = call_function[target=torch.ops.aten.instance_norm.default](args = (%x, None, None, None, None, True, 0.1, 1e-05, True), kwargs = {})
    return (instance_norm,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, %sym_size_int_4, 4, %mul_3, 4, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, %sym_size_int_4, 4, %mul_3, 4, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, %sym_size_int_4, 4, %mul_3, 4, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, %sym_size_int_4, 4, %mul_3, 4, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.sym_size.int + Operator Count: 3
- _operator.mul + Operator Count: 1
- torch.ops.aten.native_group_norm.default + Operator Count: 1
- _operator.getitem + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 6 operators out of 6 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.sym_size.int + Operator Count: 3
- _operator.mul + Operator Count: 1
- torch.ops.aten.native_group_norm.default + Operator Count: 1
- _operator.getitem + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [{'min_shape': (1, 4, 1, 1), 'opt_shape': (2, 4, 6, 8), 'max_shape': (8, 4, 8, 8)}]
 graph():
    %x : [num_users=4] = placeholder[target=x]
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 2), kwargs = {})
    %sym_size_int_6 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 3), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=operator.mul](args = (%sym_size_int_5, %sym_size_int_6), kwargs = {})
    %native_group_norm : [num_users=1] = call_function[target=torch.ops.aten.native_group_norm.default](args = (%x, None, None, %sym_size_int_4, 4, %mul_3, 4, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_group_norm, 0), kwargs = {})
    return getitem
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[-1, 4, -1, -1], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (s0, 4, s1, s2)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _empty_nn_module_stack_from_metadata_hook/sym_size_int_4 (kind: aten.sym_size.int, args: ('x <Node>', '0 <int>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _empty_nn_module_stack_from_metadata_hook/sym_size_int_4 [aten.sym_size.int] (Inputs: (x: (s0, 4, s1, s2)@torch.float32, 0) | Outputs: (sym_size_int_4: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _empty_nn_module_stack_from_metadata_hook/sym_size_int_5 (kind: aten.sym_size.int, args: ('x <Node>', '2 <int>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _empty_nn_module_stack_from_metadata_hook/sym_size_int_5 [aten.sym_size.int] (Inputs: (x: (s0, 4, s1, s2)@torch.float32, 2) | Outputs: (sym_size_int_5: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _empty_nn_module_stack_from_metadata_hook/sym_size_int_6 (kind: aten.sym_size.int, args: ('x <Node>', '3 <int>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _empty_nn_module_stack_from_metadata_hook/sym_size_int_6 [aten.sym_size.int] (Inputs: (x: (s0, 4, s1, s2)@torch.float32, 3) | Outputs: (sym_size_int_6: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/mul_3 (kind: <built-in function mul>, args: ('sym_size_int_5 <Node>', 'sym_size_int_6 <Node>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/mul_3 [<built-in function mul>] (Inputs: (sym_size_int_5: , sym_size_int_6: ) | Outputs: (mul_3: ))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/native_group_norm (kind: aten.native_group_norm.default, args: ('x <Node>', 'None <NoneType>', 'None <NoneType>', 'sym_size_int_4 <Node>', '4 <int>', 'mul_3 <Node>', '4 <int>', '1e-05 <float>'))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/native_group_norm [aten.native_group_norm.default] (Inputs: (x: (s0, 4, s1, s2)@torch.float32, None, None, sym_size_int_4: , 4, mul_3: , 4, 1e-05) | Outputs: (native_group_norm: ((s0, 4, s1, s2)@torch.float32, (s0, 4)@torch.float32, (s0, 4)@torch.float32)))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/getitem (kind: <built-in function getitem>, args: ('native_group_norm <Node>', '0 <int>'))
DEBUG:torch_tensorrt.dynamo.conversion.ops_evaluators:Evaluating _operator.getitem on object with name: m/getitem
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/getitem [<built-in function getitem>] (Inputs: (native_group_norm: ((s0, 4, s1, s2)@torch.float32, (s0, 4)@torch.float32, (s0, 4)@torch.float32), 0) | Outputs: (getitem: (s0, 4, s1, s2)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('getitem <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(-1, 4, -1, -1), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (getitem: (s0, 4, s1, s2)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.005328
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.060676
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 71644 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 104 microseconds.
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 512
DEBUG: [Torch-TensorRT] - - Runner scratch: 512 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: x
      shape: [-1, 4, -1, -1]
      dtype: Float
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [-1, 4, -1, -1]
      dtype: Float
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 6 Total Operators, of which 6 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False, enable_cross_compile_for_windows=False)

  Graph Structure:

   Inputs: List[Tensor: ((min=1, max=8), 4, (min=1, max=8), (min=1, max=8))@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: ((min=1, max=8), 4, (min=1, max=8), (min=1, max=8))@float32]
     Number of Operators in Engine: 6
     Engine Outputs: List[Tensor: ((min=1, max=8), 4, (min=1, max=8), (min=1, max=8))@float32]
    ...
   Outputs: List[Tensor: ((min=1, max=8), 4, (min=1, max=8), (min=1, max=8))@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 6.0
   Most Operators in a TRT Engine: 6

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=6 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=6 which generates 1 TRT engine(s)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [3, 4, 5, 6]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [3, 4, 5, 6]
assert_close passed

Copy link
Collaborator

@zewenli98 zewenli98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @HolyWu , this is a great improvement! LGTM, let's merge this PR after #3273 gets merged.

@apbose
Copy link
Collaborator

apbose commented Nov 14, 2024

Thanks for the PR @HolyWu. Overall the PR looks good to me too.
Just had a small doubt in the dynamic case, in this

inputs = [
        torch_tensorrt.Input(min_shape=(2, 4, 2, 2), opt_shape=(2, 4, 6, 8), max_shape=(8, 4, 8, 8), dtype=torch.float)
    ]
``` you are comparing it against input [torch.randn((3, 4, 5, 6), device="cuda")]. So what input does it take for the dynamic case? And why does it return 8 outputs there?

@HolyWu
Copy link
Contributor Author

HolyWu commented Nov 14, 2024

inputs = [
        torch_tensorrt.Input(min_shape=(2, 4, 2, 2), opt_shape=(2, 4, 6, 8), max_shape=(8, 4, 8, 8), dtype=torch.float)
    ]

you are comparing it against input [torch.randn((3, 4, 5, 6), device="cuda")]. So what input does it take for the dynamic case? And why does it return 8 outputs there?

I'm not sure I understand your question. (3, 4, 5, 6) is inside the range of min_shape and max_shape. I don't see anything wrong about it? I also don't see where it returned 8 outputs in the log?

@apbose
Copy link
Collaborator

apbose commented Nov 14, 2024

Ok I missed that you are using the same inputs for the torchtrt compiled model and the torch model. As in the inputs is redefined.

Regarding the 8 outputs in the log, I am talking about this in the dynamic case before patch

 File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 855, in output
    raise RuntimeError(
RuntimeError: Specified output dtypes (3) differ from number of outputs (8)

While executing return (div, mul, sum_2, mul_15, reshape_default, sym_size_int_4, sym_size_int_5, sym_size_int_6)
Original traceback:

Looks like in the case of the nodes with dynamic shape with no meta information, the output_dtypes is empty resulting in just 3 elements there.
Anyways since now it is not decomposing it in this way, there is no error.

@zewenli98 zewenli98 merged commit 3e8d735 into pytorch:main Nov 15, 2024
51 of 58 checks passed
@HolyWu HolyWu deleted the instance_norm_decomposition branch November 16, 2024 13:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

✨[Feature] Possibility to export nn.InstanceNorm2d
4 participants