
[Polygraphy] VRAM Overflow Without Context Manager or 3X Performance Penalty With Context Manager #3836

Closed
NevermindNilas opened this issue May 1, 2024 · 8 comments

Comments

@NevermindNilas

NevermindNilas commented May 1, 2024

Description

The issue is as follows: as far as my understanding goes, you have to .activate() and .deactivate() the generated runner, through either of these approaches:

# Method #1
with TrtRunner(self.engine) as runner:
    outputFrame = runner.infer({"input": inputFrame}, ["output"])["output"]

# Method #2
self.runner = TrtRunner(self.engine)

... 
# In a different function in my case
self.runner.activate()
outputFrame = self.runner.infer({"input": inputFrame}, ["output"])["output"] # I just want the output Tensor generated from the inference.
self.runner.deactivate()
# Where inputFrame and outputFrame are a torch tensor ( 1, 3, height, width ) with FP16 enabled

The problem is that with the above implementation, the cost of constantly activating and deactivating the engine results in similar, if not lower, performance than just using PyTorch with CUDA, for example.

One way to avoid the performance penalty would be to just call:

self.runner.activate()

and then 'never' deactivate it, but of course, as the docs indicate, this can and will cause massive VRAM overflows. I assume this happens because the engine context allocates a new memory block for each inference and holds on to it until the script ends.

My question is: is there a best-of-both-worlds approach that avoids both the performance penalty and the VRAM overflow?
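
To make the ask concrete, here is a rough sketch of the pattern I am hoping for (the wrapper class and the close() method are just illustrative, not an existing Polygraphy API):

from polygraphy.backend.trt import TrtRunner

class Upscaler:
    def __init__(self, engine):
        self.runner = TrtRunner(engine)
        self.runner.activate()  # pay the activation cost once

    def run(self, inputFrame):
        # per-frame work is inference only, no activate()/deactivate()
        return self.runner.infer({"input": inputFrame})["output"]

    def close(self):
        self.runner.deactivate()  # release resources exactly once, at shutdown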

Environment

TensorRT Version: 10.x ( latest available through pip )

NVIDIA GPU: RTX 3090

NVIDIA Driver Version: 552.22

CUDA Version: Cuda compilation tools, release 12.4, V12.4.131

CUDNN Version: Honestly, no idea.

Operating System: Windows 11

Python Version (if applicable): 3.11.9

PyTorch Version (if applicable): 2.2.2 + cu12.1

Baremetal or Container (if so, version): Baremetal

Relevant Files

Model link:
https://github.com/NevermindNilas/TAS-Modes-Host/releases/download/main/2x_ModernSpanimationV1_fp16_op17.onnx

Steps To Reproduce

Most of it can be found above

Commands or scripts:

Have you tried the latest release?: Yes

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): Yes

@pranavm-nvidia
Collaborator

infer() shouldn't usually do any allocations after the first call. The only exception to that would be if you have a model with dynamic shapes and feed in larger shapes on each subsequent call, in which case it would need to reallocate in order to accommodate the larger tensors.

Which version of Polygraphy are you using?
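
For example (sketch; the shape values are just illustrative), with a fixed input shape you can pin the profile so that min == opt == max, which means the buffer sizes never change and nothing needs to be reallocated after the first call:

from polygraphy.backend.trt import (
    CreateConfig,
    Profile,
    engine_from_network,
    network_from_onnx_path,
)

# Fixed-shape profile: min == opt == max, so infer() allocates once and reuses the buffers.
profile = Profile().add(
    "input",
    min=(1, 3, 1080, 1920),
    opt=(1, 3, 1080, 1920),
    max=(1, 3, 1080, 1920),
)

engine = engine_from_network(
    network_from_onnx_path("model.onnx"),  # placeholder path
    config=CreateConfig(fp16=True, profiles=[profile]),
)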

@NevermindNilas
Author

NevermindNilas commented May 3, 2024

infer() shouldn't usually do any allocations after the first call. The only exception to that would be if you have a model with dynamic shapes and feed in larger shapes on each subsequent call, in which case it would need to reallocate in order to accommodate the larger tensors.

Which version of Polygraphy are you using?

Hello, I am currently using the latest version available through pip, which is 0.49.x (sorry, I am on my phone, so I don't know the exact version beyond that).

In my particular use case, I input a torch tensor of (1, 3, 1088, 1920) (padded from 1080 -> 1088) and expect a torch tensor of (1, 3, 2176, 3840) (which I then slice from 2176 -> 2160).

I have only tried building it with a static profile, where min, opt, and max are all the same resolution as the input, (1, 3, 1088, 1920).
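
Roughly (from memory, so treat this as a sketch rather than my exact code), the padding and cropping around the inference call looks like this:

import torch
import torch.nn.functional as F

# 1080p input (fp16 on CUDA in the real pipeline), padded up to a height of 1088.
frame = torch.rand(1, 3, 1080, 1920)
frame = F.pad(frame, (0, 0, 0, 8))  # pad height: 1080 -> 1088

# output = runner.infer({"input": frame})["output"]  # 2x model: 1088x1920 -> 2176x3840
output = torch.rand(1, 3, 2176, 3840)  # stand-in for the model output
output = output[:, :, :2160, :]  # crop height back: 2176 -> 2160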

I will do a bit more testing and come back to this with my full implementation a bit later.

Thank you.

@NevermindNilas
Author

NevermindNilas commented May 3, 2024

Hi, I am back with some example code:

Method #1: call .activate() once and never call .deactivate():

        self.isCudaAvailable = torch.cuda.is_available() # Check if CUDA is available
        self.modelPath = r"G:\TheAnimeScripter\src\weights\span-directml\2x_ModernSpanimationV1_fp16_op17.onnx" # Path to the model file


        self.device = torch.device("cuda" if self.isCudaAvailable else "cpu") # Set the device to CUDA if available, else CPU
        if self.isCudaAvailable:
            # self.stream = [torch.cuda.Stream() for _ in range(self.nt)]
            # self.currentStream = 0
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
            if self.half:
                torch.set_default_dtype(torch.float16)

        profiles = [
            # The low-latency case. For best performance, min == opt == max.
            Profile().add(
                "input",
                min=(1, 3, self.height, self.width),
                opt=(1, 3, self.height, self.width),
                max=(1, 3, self.height, self.width),
            ),
        ]
        self.engine = engine_from_network(
            network_from_onnx_path(self.modelPath),
            config=CreateConfig(fp16=True, profiles=profiles),
        )

        self.runner = TrtRunner(self.engine) # Create a TensorRT context
        self.runner.activate() # Activate the context
    
    @torch.inference_mode()
    def run(self, frame: np.ndarray) -> np.ndarray:
        frame = (
            torch.from_numpy(frame)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .float()
            .mul_(1 / 255)
        ) # norm ops from np.uint8 to torch.float32

        frame = frame.half() if self.half and self.isCudaAvailable else frame # Convert to half precision if half is enabled and CUDA is available
        frame = self.runner.infer({"input": frame}, ["output"])["output"] # Run the model with TensorRT, outputs only the necessary output tensor

        return (
            frame.squeeze(0)
            .permute(1, 2, 0)
            .mul_(255)
            .clamp(1, 255)
            .byte()
            .cpu()
            .numpy()
        )

As expected, this causes a VRAM overflow, as seen here (the big spike is the relevant part; everything before it is the engine build):

[image: VRAM usage graph]

Polygraphy also alerts me with:
[W] trt-runner-N0-05/04/24-02:29:29 | Was activated but never deactivated. This could cause a memory leak!

FFMPEG reported FPS for this is:
frame= 240 fps= 25 q=-1.0 Lsize= 18643KiB time=00:00:09.92 bitrate=15385.1kbits/s speed=1.02x

Method #2: using with TrtRunner(self.engine) as runner:

    @torch.inference_mode()
    def run(self, frame: np.ndarray) -> np.ndarray:
        frame = (
            torch.from_numpy(frame)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .float()
            .mul_(1 / 255)
        ) # norm ops from np.uint8 to torch.float32

        frame = frame.half() if self.half and self.isCudaAvailable else frame # Convert to half precision if half is enabled and CUDA is available

        with TrtRunner(self.engine) as runner:
            frame = runner.infer({"input": frame}, ["output"])["output"] # Run the model with TensorRT, outputs only the necessary output tensor

        return (
            frame.squeeze(0)
            .permute(1, 2, 0)
            .mul_(255)
            .clamp(1, 255)
            .byte()
            .cpu()
            .numpy()
        )

VRAM usage throughout a 240-frame test:
[image: VRAM usage graph]

FFMPEG Reported performance:
frame= 240 fps=7.1 q=-1.0 Lsize= 19256KiB time=00:00:09.92 bitrate=15891.2kbits/s speed=0.296x

I have also tried with self.runner.activate() and self.runner.deactivate() and saw similar results to the one above.

As a side note, I've just learned that padding doesn't seem to be necessary for TRT inference. I might be wrong and will need to do some more testing before I can be confident about it, but that would be great.

Thank you.

@NevermindNilas
Author

Hi,

    def handleModel(self):
        """
        Load the desired model
        """
        self.isCudaAvailable = torch.cuda.is_available() # Check if CUDA is available
        self.modelPath = r"G:\TheAnimeScripter\src\weights\superultracompact-directml\2x_AnimeJaNai_HD_V3Sharp1_SuperUltraCompact_25k-fp16-sim.onnx" # Path to the model file

        self.device = torch.device("cuda" if self.isCudaAvailable else "cpu") # Set the device to CUDA if available, else CPU
        if self.isCudaAvailable:
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
            if self.half:
                torch.set_default_dtype(torch.float16)

        profiles = [
            # The low-latency case. For best performance, min == opt == max.
            Profile().add(
                "input",
                min=(1, 3, self.height, self.width),
                opt=(1, 3, self.height, self.width),
                max=(1, 3, self.height, self.width),
            ),
        ]
        self.engine = engine_from_network(
            network_from_onnx_path(self.modelPath),
            config=CreateConfig(fp16=True, profiles=profiles),
        )

        self.runner = TrtRunner(self.engine) # Create a TensorRT context
        self.runner.activate() # Activate the context
    
    @torch.inference_mode()
    def run(self, frame: np.ndarray) -> np.ndarray:
        frame = (
            torch.from_numpy(frame)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .float()
            .mul_(1 / 255)
        ) # norm ops from np.uint8 to torch.float32

        frame = frame.half() if self.half and self.isCudaAvailable else frame
        output = self.runner.infer({"input": frame}, check_inputs=False)["output"]

        return (
            output.squeeze(0)
            .permute(1, 2, 0)
            .mul_(255)
            .clamp(1, 255)
            .byte()
            .cpu()
            .numpy()
        )

So I did more testing with the code above and let it run for a good 30 seconds.

VRAM Usage:
[image: VRAM usage graph]

I was waiting to see what error it would throw, and it seems it still tries to allocate more space:

[E] 2: [executionContext.cpp::nvinfer1::rt::invokeReallocateOutput::207] Error Code 2: Internal Error (IOutputAllocator returned nullptr for allocation request of 49766400 bytes.)
[!] `execute_async_v3()` failed. Please see the logging output above for details.
frame=  811 fps= 15 q=-0.0 Lsize=N/A time=00:00:33.82 bitrate=N/A speed=0.612x
[W] trt-runner-N0-05/05/24-22:09:46     | Was activated but never deactivated. This could cause a memory leak!

This, of course, is not an issue in the with TrtRunner(self.engine) as runner case, but as mentioned earlier, activating and deactivating lowers the performance significantly.

I have also played around with the allocation_strategy: str = None parameter of TrtRunner() and set it to each of the available options ("static", "profile" & "runtime"), but with no success whatsoever.

@pranavm-nvidia
Collaborator

Can you try the latest build from the release branch?

@NevermindNilas
Author

NevermindNilas commented May 6, 2024

Running Polygraphy v0.49.10, installed via the install.ps1 script, with the same code as sent earlier (only .activate(), never .deactivate()).

[image: VRAM usage graph]

Terminal output:

[I] Configuring with profiles:[
        Profile 0:
            {input [min=(1, 3, 1080, 1920), opt=(1, 3, 1080, 1920), max=(1, 3, 1080, 1920)]}
    ]
[W] profileSharing0806 is on by default in TensorRT 10.0. This flag is deprecated and has no effect.
[I] Building engine with configuration:
    Flags                  | [FP16]
    Engine Capability      | EngineCapability.STANDARD
    Memory Pools           | [WORKSPACE: 24575.50 MiB, TACTIC_DRAM: 24575.50 MiB, TACTIC_SHARED_MEMORY: 1024.00 MiB]
    Tactic Sources         | [EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Preview Features       | [PROFILE_SHARING_0806]
[I] Finished engine building in 30.102 seconds
frame=  240 fps= 34 q=-0.0 Lsize=N/A time=00:00:10.01 bitrate=N/A speed=1.42x    
[W] trt-runner-N0-05/06/24-20:01:45     | Was activated but never deactivated. This could cause a memory leak!

This seems to work just fine, with no OOM issues anymore across a quick two-run session with two different models. Performance is slightly lower than expected, but that could be attributed to my own unoptimized code.

Thank you.

As an additional question, are there any flags that would help further improve performance and / or reduce the engine build time?

Also, is there a way to suppress the unwanted warnings: [W] trt-runner-N0-05/06/24-20:01:45 | Was activated but never deactivated. This could cause a memory leak! and [W] profileSharing0806 is on by default in TensorRT 10.0. This flag is deprecated and has no effect.

Thank you.

@pranavm-nvidia
Collaborator

pranavm-nvidia commented May 7, 2024

See this section for information on reducing build time. The options mentioned there should all be available via CreateConfig.

Regarding the warnings, you should call self.runner.deactivate() at some point before the program exits (you can use the atexit module for this). You should be able to omit the PROFILE_SHARING_0806 flag by passing preview_features=[] to CreateConfig.
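
Something along these lines should work (sketch; the model path is just a placeholder):

import atexit

from polygraphy.backend.trt import (
    CreateConfig,
    TrtRunner,
    engine_from_network,
    network_from_onnx_path,
)

# Pass an empty preview-feature list so Polygraphy does not explicitly set PROFILE_SHARING_0806.
engine = engine_from_network(
    network_from_onnx_path("model.onnx"),  # placeholder path
    config=CreateConfig(fp16=True, preview_features=[]),
)

runner = TrtRunner(engine)
runner.activate()  # activate once, up front
atexit.register(runner.deactivate)  # guaranteed deactivate() at interpreter exit

# ... call runner.infer(...) as many times as needed, with no per-call activate/deactivate ...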

@NevermindNilas
Author

Thank you very much, I will look into it.
