
Add support for converting an inpainting model to ONNX and TensorRT #1831

Merged: 15 commits into open-mmlab:master, Mar 29, 2023

Conversation

KKIEEK
Contributor

@KKIEEK KKIEEK commented Mar 4, 2023

Modification

I added a new feature to convert inpainting models in MMEditing to ONNX and TensorRT.

However, I can't provide a quantitative evaluation because of the high computational cost of evaluating on CelebA and Places2 (the datasets are too large).
If you have any ideas, please let me know.

Use cases (Optional)

# Download
pip install openmim
mim download mmedit --config AOT-GAN_512x512_4x12_places --dest aot-gan
mim download mmedit --config deepfillv1_256x256_8x2_places --dest deepfillv1
mim download mmedit --config deepfillv2_256x256_8x2_places --dest deepfillv2
mim download mmedit --config gl_256x256_8x12_places --dest global_local
mim download mmedit --config pconv_256x256_stage2_4x2_places --dest partial_conv

# Convert PyTorch to ONNX
pip install pycuda onnxruntime-gpu
python tools/torch2onnx.py configs/mmedit/inpainting/inpainting_onnxruntime_static.py deepfillv2/deepfillv2_256x256_8x2_places.py deepfillv2/deepfillv2_256x256_8x2_places_20200619-10d15793.pth tests/data/tiger.jpeg
...

# Convert ONNX to TensorRT
pip install tensorrt
python tools/onnx2tensorrt.py configs/mmedit/inpainting/inpainting_tensorrt_static-256x256.py work-dir/end2end.onnx tensorrt --device-id 0
...

@CLAassistant

CLAassistant commented Mar 4, 2023

CLA assistant check
All committers have signed the CLA.

@lvhan028
Collaborator

lvhan028 commented Mar 6, 2023

Hi, @KKIEEK
Since we are going to release OpenMMLab 2.0 soon (i.e., mmediting 1.x and mmdeploy 1.x, with master becoming a legacy branch), could you work on mmediting 1.x and open this PR against mmdeploy's dev-1.x?

@KKIEEK
Contributor Author

KKIEEK commented Mar 6, 2023

@lvhan028
Sorry, since I'm still using mmedit 0.x, it would be difficult for me to work on the 1.x branch.
Is there no plan to support the legacy version in the future?

@lvhan028
Collaborator

lvhan028 commented Mar 6, 2023

@lvhan028 Sorry, since I'm still using mmedit 0.x, it would be difficult for me to work on the 1.x branch. Is there no plan to support the legacy version in the future?

Don't worry. We still support the legacy version.

@grimoire
Member

grimoire commented Mar 14, 2023

The use cases failed with log:

[03/14/2023-20:00:18] [TRT] [E] [shuffleNode.cpp::symbolicExecute::392] Error Code 4: Internal Error (Reshape_13: IShuffleLayer applied to shape tensor must have 0 or 1 reshape dimensions: dimensions were [-1,2])
[03/14/2023-20:00:18] [TRT] [E] ModelImporter.cpp:773: While parsing node number 19 [Pad -> "145"]:
[03/14/2023-20:00:18] [TRT] [E] ModelImporter.cpp:774: --- Begin node ---
[03/14/2023-20:00:18] [TRT] [E] ModelImporter.cpp:775: input: "112"
input: "144"
output: "145"
name: "Pad_19"
op_type: "Pad"
attribute {
  name: "mode"
  s: "reflect"
  type: STRING
}

[03/14/2023-20:00:18] [TRT] [E] ModelImporter.cpp:776: --- End node ---
[03/14/2023-20:00:18] [TRT] [E] ModelImporter.cpp:779: ERROR: ModelImporter.cpp:180 In function parseGraph:
[6] Invalid Node - Pad_19
[shuffleNode.cpp::symbolicExecute::392] Error Code 4: Internal Error (Reshape_13: IShuffleLayer applied to shape tensor must have 0 or 1 reshape dimensions: dimensions were [-1,2])
RuntimeError: Failed to parse onnx, In node 19 (parseGraph): INVALID_NODE: Invalid Node - Pad_19
[shuffleNode.cpp::symbolicExecute::392] Error Code 4: Internal Error (Reshape_13: IShuffleLayer applied to shape tensor must have 0 or 1 reshape dimensions: dimensions were [-1,2])

My envs:

2023-03-14 20:03:03,743 - mmdeploy - INFO - TorchVision: 0.11.2+cu113
2023-03-14 20:03:03,743 - mmdeploy - INFO - OpenCV: 4.5.4
2023-03-14 20:03:03,743 - mmdeploy - INFO - MMCV: 1.7.1
2023-03-14 20:03:03,743 - mmdeploy - INFO - MMCV Compiler: GCC 9.4
2023-03-14 20:03:03,743 - mmdeploy - INFO - MMCV CUDA Compiler: 11.3
2023-03-14 20:03:03,743 - mmdeploy - INFO - MMDeploy: 0.13.0+3682c60
2023-03-14 20:03:03,743 - mmdeploy - INFO - 

2023-03-14 20:03:03,743 - mmdeploy - INFO - **********Backend information**********
2023-03-14 20:03:03,782 - mmdeploy - INFO - tensorrt:   8.5.1.7
2023-03-14 20:03:03,782 - mmdeploy - INFO - tensorrt custom ops:        Available
2023-03-14 20:03:03,900 - mmdeploy - INFO - ONNXRuntime:        None
2023-03-14 20:03:03,900 - mmdeploy - INFO - ONNXRuntime-gpu:    1.13.1
2023-03-14 20:03:03,900 - mmdeploy - INFO - ONNXRuntime custom ops:     Available

2023-03-14 20:03:04,587 - mmdeploy - INFO - **********Codebase information**********
2023-03-14 20:03:05,683 - mmdeploy - INFO - mmedit:     0.16.1

@KKIEEK
Contributor Author

KKIEEK commented Mar 14, 2023

Honestly speaking, I don't know how to solve this problem.
I aligned my TensorRT and ONNX Runtime versions with yours, but it still works well in my case.
My environment is as follows (I tested it in a Colab environment):

2023-03-14 13:09:52,235 - mmdeploy - INFO - **********Environmental information**********
2023-03-14 13:09:52,612 - mmdeploy - INFO - sys.platform: linux
2023-03-14 13:09:52,613 - mmdeploy - INFO - Python: 3.9.16 (main, Dec  7 2022, 01:11:51) [GCC 9.4.0]
2023-03-14 13:09:52,613 - mmdeploy - INFO - CUDA available: True
2023-03-14 13:09:52,613 - mmdeploy - INFO - GPU 0: Tesla T4
2023-03-14 13:09:52,613 - mmdeploy - INFO - CUDA_HOME: /usr/local/cuda
2023-03-14 13:09:52,613 - mmdeploy - INFO - NVCC: Cuda compilation tools, release 11.8, V11.8.89
2023-03-14 13:09:52,613 - mmdeploy - INFO - GCC: x86_64-linux-gnu-gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
2023-03-14 13:09:52,613 - mmdeploy - INFO - PyTorch: 1.13.1+cu116
2023-03-14 13:09:52,613 - mmdeploy - INFO - PyTorch compiling details: PyTorch built with:
...
2023-03-14 13:09:52,613 - mmdeploy - INFO - TorchVision: 0.14.1+cu116
2023-03-14 13:09:52,613 - mmdeploy - INFO - OpenCV: 4.6.0
2023-03-14 13:09:52,613 - mmdeploy - INFO - MMCV: 1.7.1
2023-03-14 13:09:52,613 - mmdeploy - INFO - MMCV Compiler: GCC 9.3
2023-03-14 13:09:52,613 - mmdeploy - INFO - MMCV CUDA Compiler: 11.6
2023-03-14 13:09:52,613 - mmdeploy - INFO - MMDeploy: 0.13.0+3682c60
2023-03-14 13:09:52,613 - mmdeploy - INFO - 

2023-03-14 13:09:52,613 - mmdeploy - INFO - **********Backend information**********
2023-03-14 13:09:52,662 - mmdeploy - INFO - tensorrt:	8.5.3.1
2023-03-14 13:09:52,662 - mmdeploy - INFO - tensorrt custom ops:	NotAvailable
2023-03-14 13:09:52,737 - mmdeploy - INFO - ONNXRuntime:	None
2023-03-14 13:09:52,737 - mmdeploy - INFO - ONNXRuntime-gpu:	1.14.1
2023-03-14 13:09:52,737 - mmdeploy - INFO - ONNXRuntime custom ops:	NotAvailable

2023-03-14 13:09:52,866 - mmdeploy - INFO - **********Codebase information**********
2023-03-14 13:09:52,868 - mmdeploy - INFO - mmedit:	0.16.1

@grimoire
Member

@lvhan028 @irexyc Please have a try

@irexyc
Collaborator

irexyc commented Mar 15, 2023

2023-03-14 13:09:52,613 - mmdeploy - INFO - PyTorch: 1.13.1+cu116
2023-03-14 13:09:52,613 - mmdeploy - INFO - TorchVision: 0.14.1+cu116
2023-03-14 13:09:52,662 - mmdeploy - INFO - tensorrt: 8.5.3.1

2023-03-14 13:09:52,613 - mmdeploy - INFO - PyTorch: 1.10.1
2023-03-14 13:09:52,613 - mmdeploy - INFO - TorchVision: 0.11.2
2023-03-14 13:09:52,662 - mmdeploy - INFO - tensorrt: 8.5.1.7

Both work.

@lvhan028
Collaborator

Hi, @KKIEEK
Could you update the supported models in docs/en/04-supported-codebases/mmedit.md?
Since a quantitative evaluation cannot be provided, I think it's better to put a '*' after each model and add a note explaining it.

@KKIEEK
Contributor Author

KKIEEK commented Mar 21, 2023

@lvhan028 I updated the table in docs/en/04-supported-codebases/mmedit.md.

@lvhan028
Collaborator

I met the same issue when I tried it with PyTorch 1.8 and PyTorch 1.10.1+cu111:

[shuffleNode.cpp::symbolicExecute::387] Error Code 4: Internal Error (Reshape_14: IShuffleLayer applied to shape tensor must have 0 or 1 reshape dimensions: dimensions were [-1,2])

@KKIEEK
Contributor Author

KKIEEK commented Mar 22, 2023

I found a related issue: NVIDIA/TensorRT#2484.
@lvhan028 Could you try again using tensorrt==8.5.x.x without building a custom TRT?

@lvhan028
Collaborator

I found a related issue NVIDIA/TensorRT#2484 @lvhan028 Could you try again using tensorrt==8.5.x.x without building a custom TRT?

Yes. After upgrading tensorrt from 8.2.3.0 to 8.5.3.1, tools/onnx2tensorrt.py runs successfully.
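Since the parser error above is opaque, one option is to fail fast on old TensorRT versions before attempting conversion. The helper below is a hypothetical sketch (not part of this PR); the 8.5 threshold follows the report in NVIDIA/TensorRT#2484 and this conversation, and `version_str` would typically come from `tensorrt.__version__`.

```python
def check_trt_version(version_str, minimum=(8, 5)):
    """Raise if the installed TensorRT predates `minimum`.

    Reflect-mode Pad export reportedly fails to parse on TensorRT
    before 8.5, so failing fast gives a clearer message than the
    parser's internal shuffleNode error.
    """
    # Compare only major.minor, e.g. "8.5.3.1" -> (8, 5)
    version = tuple(int(p) for p in version_str.split(".")[:2])
    if version < minimum:
        raise RuntimeError(
            f"TensorRT {version_str} is too old for this model; "
            f"please upgrade to >= {'.'.join(map(str, minimum))}")
    return True


check_trt_version("8.5.3.1")  # passes
```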

@lvhan028
Collaborator

Visualizing the tensorrt model crashed.

@KKIEEK
Contributor Author

KKIEEK commented Mar 23, 2023

@lvhan028
Could you show me the error log? In my case there are some bugs (e.g., a reversed channel order), but no crash in the visualize_model API.

@lvhan028
Collaborator

Traceback (most recent call last):
  File "/home/PJLAB/lvhan/Documents/projects/open-mmlab/mmdeploy/mmdeploy/utils/utils.py", line 41, in target_wrapper
    result = target(*args, **kwargs)
  File "/home/PJLAB/lvhan/Documents/projects/open-mmlab/mmdeploy/mmdeploy/apis/visualize.py", line 72, in visualize_model
    result = task_processor.run_inference(model, model_inputs)[0]
  File "/home/PJLAB/lvhan/Documents/projects/open-mmlab/mmdeploy/mmdeploy/codebase/mmedit/deploy/inpainting.py", line 219, in run_inference
    if not isinstance(results[0], np.ndarray):
KeyError: 0
2023-03-23 16:12:53,956 - mmdeploy - ERROR - visualize pytorch model failed.
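The `KeyError: 0` in the traceback suggests that `results` arrived as a dict keyed by output name rather than a list, so `results[0]` looked up a dict key instead of the first element. The sketch below is hypothetical (the container types and the `"output"` key are assumptions, not taken from the PR) and only illustrates a defensive way to handle both shapes:

```python
import numpy as np

def first_output(results):
    """Return the first output tensor whether `results` is a list or a
    name-keyed dict; plain `results[0]` raises KeyError on the dict case."""
    if isinstance(results, dict):
        results = list(results.values())
    return results[0]

# List-style results, as a runtime returning ordered outputs might produce
list_style = [np.zeros((1, 3, 256, 256), dtype=np.float32)]
# Dict-style results keyed by output name, which breaks `results[0]`
dict_style = {"output": np.zeros((1, 3, 256, 256), dtype=np.float32)}

print(first_output(list_style).shape)  # (1, 3, 256, 256)
print(first_output(dict_style).shape)  # (1, 3, 256, 256)
```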

@lvhan028
Collaborator

This is my test command:

python tools/deploy.py configs/mmedit/inpainting/inpainting_tensorrt-fp16_static-256x256.py /data1/checkpoint/mmedit/deepfillv2/deepfillv2_256x256_8x2_places.py /data1/checkpoint/mmedit/deepfillv2/deepfillv2_256x256_8x2_places_20200619-10d15793.pth --device cuda demo/resources/det.jpg --work-dir /data1/mmdeploy_models/mmedit/trt/deepfillv2

@KKIEEK
Contributor Author

KKIEEK commented Mar 24, 2023

I fixed it. @lvhan028
And, as far as I remember, converting DeepFill to TensorRT with fp16 does not work correctly, because intermediate values in the softmax exceed the fp16 range in the ContextualAttention module. For more details, please refer to this line.
Except for DeepFill, converting to TensorRT with fp16 works well.
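The overflow mechanism can be demonstrated in isolation with NumPy half precision (this is an illustrative sketch, not the ContextualAttention code): `exp()` in fp16 overflows for inputs above roughly 11, so a naive softmax over large attention scores produces inf/nan, while the usual max-subtraction trick stays finite.

```python
import numpy as np

def softmax_fp16_naive(x):
    # exp() in half precision overflows to inf for inputs above ~11,
    # and inf / inf then yields nan
    e = np.exp(x.astype(np.float16))
    return e / e.sum()

def softmax_fp16_stable(x):
    # subtracting the max keeps every exp() input <= 0, so nothing overflows
    x = x.astype(np.float16)
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([20.0, 1.0, 0.5])  # 20 is past fp16's exp range
with np.errstate(over='ignore', invalid='ignore'):
    naive = softmax_fp16_naive(logits)   # contains nan/inf artifacts
stable = softmax_fp16_stable(logits)     # finite, sums to ~1
```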

PyTorch (DeepFillv2)

[image: output_pytorch]

TensorRT (fp32)

[image: output_tensorrt]

TensorRT (fp16)

[image: output_tensorrt-2]

@lvhan028
Collaborator

lvhan028 commented Mar 27, 2023

I fixed it. @lvhan028 And, as far as I remember, converting DeepFill to TensorRT with fp16 does not work correctly, due to intermediate values in softmax exceeding the fp16 range in the ContextualAttention module. For more details, please refer to this line. Except for DeepFill, converting to TensorRT with fp16 works well.


As shown in this experiment, the outputs of the pytorch model and the tensorrt (fp32) model are different.
That's because of RandomResizeCrop, or Crop with random=True.
I think we need to remove them from test_pipeline.
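For illustration, a deterministic pipeline along these lines might look as follows. This is a hypothetical mmedit 0.x-style config fragment: the transform names, keys, and values are assumptions for the sketch, not taken from the PR, and as the discussion below notes, `LoadMask` itself can still introduce randomness.

```python
# Hypothetical test_pipeline with the random crop removed so that PyTorch
# and TensorRT see identical inputs; all names/values are illustrative.
test_pipeline = [
    dict(type='LoadImageFromFile', key='gt_img'),
    # Note: LoadMask with a random mask_mode would still be nondeterministic
    dict(type='LoadMask', mask_mode='bbox',
         mask_config=dict(max_bbox_shape=128, img_shape=(256, 256))),
    # dict(type='RandomResizeCrop', ...)  # dropped: random at test time
    dict(type='Resize', keys=['gt_img'], scale=(256, 256), keep_ratio=False),
    dict(type='Normalize', keys=['gt_img'],
         mean=[127.5] * 3, std=[127.5] * 3, to_rgb=True),
]

# Deployment comparisons stay reproducible only if no step is random
assert not any('Random' in step['type'] for step in test_pipeline)
```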

@KKIEEK
Contributor Author

KKIEEK commented Mar 27, 2023

I have some questions:

  1. LoadMask also randomly transforms the input image. Should I modify the LoadMask pipeline too? As far as I know, in the current MMDeploy implementation there is no way to pass in the mask.
  2. In MMEdit, it seems that test_pipeline is used to evaluate a test dataset during training. However, for inference, MMEdit defines a new pipeline, as in this link. How can I separate the two pipelines?

@lvhan028
Collaborator

lvhan028 commented Mar 28, 2023

I have some questions:

  1. LoadMask also randomly transforms the input image. Should I modify the LoadMask pipeline too? As far as I know, in the current MMDeploy implementation there is no way to pass in the mask.
  2. In MMEdit, it seems that test_pipeline is used to evaluate a test dataset during training. However, for inference, MMEdit defines a new pipeline, as in this link. How can I separate the two pipelines?

I see. From the perspective of converting an inpainting pytorch model to a backend model, I think this PR works, and I am going to approve it. Good job!

Regarding actual model inference, this pipeline seems necessary. We will consider how to support it.

@KKIEEK
Contributor Author

KKIEEK commented Mar 28, 2023

I appreciate your hard work; please tell me whenever you need any help with this PR.

@lvhan028 lvhan028 merged commit f7c484a into open-mmlab:master Mar 29, 2023
@RunningLeon
Collaborator

@KKIEEK Hi, many thanks for your PR. If you have time and interest in adding this feature to the mmdeploy main branch (which is based on mmagic), please let me know. BR.

@KKIEEK
Contributor Author

KKIEEK commented Aug 24, 2023

@RunningLeon Of course, I'll try it if you need this feature. However, I think it will take at least 3 weeks to draft a PR. Could you wait for that?

@RunningLeon
Collaborator

@KKIEEK Hi, great. No hurry. If you have any questions, please let me know. Thanks in advance.
