🐛 [Bug] Parse failure in fallback cases #1112

Closed
Njuapp opened this issue Jun 13, 2022 · 2 comments
Labels: bug, component: partitioning

Comments

Njuapp (Contributor) commented Jun 13, 2022

Bug Description

When some ops fall back to PyTorch, the compiled TorchScript may contain a dot (.) in parameter names, e.g., x.1 and input.1. Such a module can be saved to disk, but it cannot be loaded back, because a parameter name containing . is not a valid identifier in the serialized code.

Traceback (most recent call last):
  File "main.py", line 56, in <module>
    main()
  File "main.py", line 39, in main
    model = torch.jit.load('resnet_withfallback.trt.ts')
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: expected ) but found 'number' here:
Serialized   File "code/__torch__/torchvision/models/resnet.py", line 7
  __torch___torchvision_models_resnet_ResNet_trt_engine_0x5595a2dc5070 : __torch__.torch.classes.tensorrt.Engine
  def forward(self_1: __torch__.torchvision.models.resnet.ResNet_trt,
    x.1: Tensor) -> Tensor:
     ~~ <--- HERE
    __torch___torchvision_models_resnet_ResNet_trt_engine_0x5595a2dc4ed0 = self_1.__torch___torchvision_models_resnet_ResNet_trt_engine_0x5595a2dc4ed0
    _0 = ops.tensorrt.execute_engine([x.1], __torch___torchvision_models_resnet_ResNet_trt_engine_0x5595a2dc4ed0)
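
As a rough illustration of why the loader stops at x.1, the sketch below uses Python's own parser as a stand-in. This is only an analogy: the error above is raised by TorchScript's importer, not by Python's ast module, but both reject a dot inside a parameter name. The helper sanitize_name is hypothetical and is not the fix applied in Torch-TensorRT.

```python
import ast
import re


def parses_as_signature(param_name: str) -> bool:
    """Return True if `param_name` can be used as a parameter in a Python def."""
    source = f"def forward(self_1, {param_name}):\n    pass\n"
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


def sanitize_name(name: str) -> str:
    """Hypothetical helper: replace characters that are illegal in identifiers."""
    cleaned = re.sub(r"\W", "_", name)
    # An identifier cannot start with a digit, so prefix an underscore if needed.
    if cleaned and cleaned[0].isdigit():
        cleaned = "_" + cleaned
    return cleaned


for name in ["x.1", "input.1", "x"]:
    # Columns: original name, valid identifier?, parses in a def?, sanitized form
    print(name, name.isidentifier(), parses_as_signature(name), sanitize_name(name))
```

A renaming scheme like this can introduce collisions (for example, x.1 and an existing x_1), so a real fix would also need to keep the new names unique.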

To Reproduce

Run the following Python script:

import torch
import torch_tensorrt
import torchvision
import time

torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)

warmup_time = 10
test_time = 100
def main():
    torch.manual_seed(2022)
    im_rand = torch.rand((16, 3, 224, 224)).cuda()
    resnet_model = torchvision.models.resnet50()
    resnet_model = resnet_model.eval().cuda()
    torch_script_module = torch.jit.trace(resnet_model, (im_rand))
    # torch.jit.save(torch_script_module.half(), 'resnet18.pt')

    # torch_script_module = torch.jit.load('resnet.pt')
    input = torch.rand((32, 3, 224, 224)).cuda()
    ts_output = torch_script_module(input)
    print('result:', ts_output)

    compile_settings = {
        # "inputs": [input],
        "inputs": [torch_tensorrt.Input(
            # min_shape=[1, 3, 224, 224],  # TODO: depends on the model size
            # opt_shape=[32, 3, 224, 224],
            # max_shape=[32, 3, 224, 224],
            shape=[1, 3, 224, 224],
            dtype=torch.half,  # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
        )],
        'torch_executed_ops': ['aten::adaptive_avg_pool2d'],
        "enabled_precisions": {torch.half}
    }

    print("Compile: FP16")
    model = torch_tensorrt.compile(torch_script_module.half(), **compile_settings)
    torch.jit.save(model, 'resnet_withfallback.trt.ts')
    model = torch.jit.load('resnet_withfallback.trt.ts')

    for i in range(warmup_time):
        trt_output = model(input.half())

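    print("Benchmark: FP16")
    torch.cuda.synchronize()  # finish the warmup work before starting the timer
    t1 = time.time()
    for i in range(test_time):
        trt_output = model(input.half())
    torch.cuda.synchronize()  # wait for queued GPU work so the timing is meaningful
    t2 = time.time()
    print('average latency: {:.3f} ms'.format((t2 - t1) / test_time * 1000))
    print('trt_output: ', trt_output)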

    diff = abs(ts_output - trt_output)
    print('diff: {}'.format(diff.mean()))


if __name__ == '__main__':
    main()

Expected behavior

The compiled module should save with torch.jit.save and load back with torch.jit.load without errors.
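
One way to check whether a saved file is affected, without calling torch.jit.load, is to scan its serialized code for dotted names. The sketch below is a heuristic under stated assumptions: it assumes the saved file is a zip archive with its Python-like sources under a code/ directory (as the path in the traceback suggests), and the function find_dotted_names and the regular expression are illustrative, not part of PyTorch or Torch-TensorRT. The demo builds a small in-memory archive with a shortened version of the signature from the traceback rather than reading a real model file.

```python
import io
import re
import zipfile

# Heuristic: an identifier immediately followed by ".<digits>", e.g. "x.1".
DOTTED_NAME = re.compile(r"(?<![\w.])[A-Za-z_]\w*\.\d+\b")


def find_dotted_names(archive) -> dict:
    """Map each code file in a saved TorchScript archive to the dotted names in it.

    `archive` is a path or a binary file-like object. Assumes a zip archive
    whose sources sit under a "code/" directory and end in ".py".
    """
    hits = {}
    with zipfile.ZipFile(archive) as zf:
        for member in zf.namelist():
            if "/code/" not in f"/{member}" or not member.endswith(".py"):
                continue
            text = zf.read(member).decode("utf-8", errors="replace")
            names = sorted(set(DOTTED_NAME.findall(text)))
            if names:
                hits[member] = names
    return hits


# Demo on a stand-in archive (not a real model file).
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr(
        "resnet_withfallback/code/__torch__/torchvision/models/resnet.py",
        "def forward(self_1: ResNet_trt,\n"
        "    x.1: Tensor) -> Tensor:\n"
        "    _0 = ops.tensorrt.execute_engine([x.1], engine)\n",
    )
buf.seek(0)
print(find_dotted_names(buf))
```

Because the pattern is purely textual, it can also match inside string literals or comments, so any hit is a hint to inspect the file rather than proof of this bug.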

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.1.0):
  • PyTorch Version (e.g. 1.11):
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): bazel
  • Python version: 3.8
  • CUDA version: 11.3
Njuapp added the bug label Jun 13, 2022
bowang007 (Collaborator) commented

Similar to #973.

bowang007 (Collaborator) commented

Closing per #1148.
