[Polygraphy] VRAM Overflow Without Context Manager or 3X Performance Penalty With Context Manager #3836
Comments
Which version of Polygraphy are you using?
Hello, I am currently using the latest version available through pip, which is 0.49.x (sorry, I am on my phone, so I don't know the exact version beyond that). In my particular use case, I input a torch tensor of shape (1, 3, 1088, 1920) (padded from 1080 to 1088) and expect a torch tensor of shape (1, 3, 2176, 3840) (which I then slice from 2176 back down to 2160). I have only tried building with a static profile where min, opt, and max are all the same resolution as the input, (1, 3, 1088, 1920). I will do a bit more testing and come back with my full implementation a bit later. Thank you.
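For context, a rough sketch of the padding and slicing described above; the upscaler here is only a stand-in for the 2x ONNX model, not the actual implementation:

import torch
import torch.nn.functional as F

# Stand-in for the 2x super-resolution model (illustrative only)
upscaler = lambda x: F.interpolate(x, scale_factor=2, mode="nearest")

frame = torch.zeros(1, 3, 1080, 1920)   # illustrative 1080p input tensor
frame = F.pad(frame, (0, 0, 0, 8))      # pad height 1080 -> 1088 (bottom padding)
out = upscaler(frame)                   # -> (1, 3, 2176, 3840)
out = out[:, :, :2160, :]               # slice height 2176 -> 2160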
Hi,

def handleModel(self):
"""
Load the desired model
"""
self.isCudaAvailable = torch.cuda.is_available() # Check if CUDA is available
self.modelPath = r"G:\TheAnimeScripter\src\weights\superultracompact-directml\2x_AnimeJaNai_HD_V3Sharp1_SuperUltraCompact_25k-fp16-sim.onnx" # Path to the model file
self.device = torch.device("cuda" if self.isCudaAvailable else "cpu") # Set the device to CUDA if available, else CPU
if self.isCudaAvailable:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
if self.half:
torch.set_default_dtype(torch.float16)
profiles = [
# The low-latency case. For best performance, min == opt == max.
Profile().add(
"input",
min=(1, 3, self.height, self.width),
opt=(1, 3, self.height, self.width),
max=(1, 3, self.height, self.width),
),
]
self.engine = engine_from_network(
network_from_onnx_path(self.modelPath),
config=CreateConfig(fp16=True, profiles=profiles),
)
self.runner = TrtRunner(self.engine) # Create a TensorRT runner from the engine
self.runner.activate() # Activate the runner (creates the execution context and allocates buffers)
@torch.inference_mode()
def run(self, frame: np.ndarray) -> np.ndarray:
frame = (
torch.from_numpy(frame)
.permute(2, 0, 1)
.unsqueeze(0)
.float()
.mul_(1 / 255)
) # norm ops from np.uint8 to torch.float32
frame = frame.half() if self.half and self.isCudaAvailable else frame
output = self.runner.infer({"input": frame}, check_inputs=False)["output"]
return (
output.squeeze(0)
.permute(1, 2, 0)
.mul_(255)
.clamp(1, 255)
.byte()
.cpu()
.numpy()
)

So I did more testing with the code above and let it run for a good 30 seconds. I was waiting to see what error it would throw, and it seems it still tries to allocate more memory:

[E] 2: [executionContext.cpp::nvinfer1::rt::invokeReallocateOutput::207] Error Code 2: Internal Error (IOutputAllocator returned nullptr for allocation request of 49766400 bytes.)
[!] `execute_async_v3()` failed. Please see the logging output above for details.
frame= 811 fps= 15 q=-0.0 Lsize=N/A time=00:00:33.82 bitrate=N/A speed=0.612x
[W] trt-runner-N0-05/05/24-22:09:46 | Was activated but never deactivated. This could cause a memory leak!

This of course is not an issue with the … I have also played around with the …
Can you try the latest build from the release branch?
Running Polygraphy v0.49.10, installed with the install.ps1 command, and the same code as sent earlier (only .activate(), never .deactivate()).

Terminal output:

[I] Configuring with profiles:[
Profile 0:
{input [min=(1, 3, 1080, 1920), opt=(1, 3, 1080, 1920), max=(1, 3, 1080, 1920)]}
]
[W] profileSharing0806 is on by default in TensorRT 10.0. This flag is deprecated and has no effect.
[I] Building engine with configuration:
Flags | [FP16]
Engine Capability | EngineCapability.STANDARD
Memory Pools | [WORKSPACE: 24575.50 MiB, TACTIC_DRAM: 24575.50 MiB, TACTIC_SHARED_MEMORY: 1024.00 MiB]
Tactic Sources | [EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
Profiling Verbosity | ProfilingVerbosity.DETAILED
Preview Features | [PROFILE_SHARING_0806]
[I] Finished engine building in 30.102 seconds
frame= 240 fps= 34 q=-0.0 Lsize=N/A time=00:00:10.01 bitrate=N/A speed=1.42x
[W] trt-runner-N0-05/06/24-20:01:45 | Was activated but never deactivated. This could cause a memory leak!

This seems to work just fine, with no OOM issues anymore, after a quick two-run session with two different models, except for slightly slower performance than expected, but that could be attributed to my own unoptimized code. Thank you.

As an additional question, are there any flags that would help further improve the performance and/or the build time of the engine? Also, is there a way to suppress unwanted warnings like the one above? Thank you.
See this section for information on reducing build time. The options mentioned there should all be available via CreateConfig. Regarding the warnings, you should call deactivate() on the runner once you are done with it.
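To illustrate that suggestion, here is a minimal sketch of pairing the existing activate() call with a deactivate() once all frames have been processed; the cleanup method name is hypothetical and not part of the original code:

def cleanup(self):
    # Hypothetical teardown hook: call once after the last frame is processed.
    # Deactivating releases the runner's execution context and device buffers,
    # and silences the "activated but never deactivated" warning.
    if getattr(self, "runner", None) is not None:
        self.runner.deactivate()
        self.runner = None

Calling this once at the end of the run keeps the single-activation performance while still releasing GPU memory cleanly.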
Thank you very much, I will look into it.
Description
The issue is as follows: as far as my understanding goes, you have to .activate() and .deactivate() the generated runner class through:
The problem is that, with the above implementation, the performance penalty of constantly activating and deactivating the engine results in similar, if not lower, performance than just using PyTorch with CUDA, for example.
One solution to bypass the performance penalty would be just doing:
And then 'never' deactivating it, but of course, as the docs indicate, this can and will cause massive VRAM overflows. I assume this happens because the engine context allocates a new memory block for each inference and holds on to the previous ones until the script ends.
My question is: is there a best-of-both-worlds approach where you bypass both the performance penalty and the VRAM overflow?
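For reference, a minimal sketch of the two usage patterns described above, assuming the same Polygraphy API used in the code earlier in this thread; the model path, input name, and frame data are illustrative, not taken from the original code:

import numpy as np
from polygraphy.backend.trt import (
    CreateConfig,
    TrtRunner,
    engine_from_network,
    network_from_onnx_path,
)

engine = engine_from_network(
    network_from_onnx_path("model.onnx"),  # illustrative model path
    config=CreateConfig(fp16=True),
)

frames = [np.zeros((1, 3, 1088, 1920), dtype=np.float32)]  # illustrative inputs

# Pattern 1: context manager per inference. Safe, but pays the
# activate()/deactivate() cost (context creation, buffer allocation) on every frame.
def infer_one(frame: np.ndarray) -> np.ndarray:
    with TrtRunner(engine) as runner:
        return runner.infer({"input": frame})["output"]

# Pattern 2: activate once, reuse the runner for every frame, and
# deactivate only after the whole stream has been processed.
runner = TrtRunner(engine)
runner.activate()
try:
    for frame in frames:
        output = runner.infer({"input": frame})["output"]
finally:
    runner.deactivate()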
Environment
TensorRT Version: 10.x (latest available through pip)
NVIDIA GPU: RTX 3090
NVIDIA Driver Version: 552.22
CUDA Version: Cuda compilation tools, release 12.4, V12.4.131
CUDNN Version: Honestly, no idea.
Operating System: Windows 11
Python Version (if applicable): 3.11.9
PyTorch Version (if applicable): 2.2.2 + cu12.1
Baremetal or Container (if so, version): Baremetal
Relevant Files
Model link:
https://github.com/NevermindNilas/TAS-Modes-Host/releases/download/main/2x_ModernSpanimationV1_fp16_op17.onnx
Steps To Reproduce
Most of it can be found above
Commands or scripts:
Have you tried the latest release?: Yes
Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (
polygraphy run <model.onnx> --onnxrt
): Yes