feat: engine caching
zewenli98 committed Jul 10, 2024
1 parent feb4d84 commit 4e69f5d
Showing 7 changed files with 365 additions and 12 deletions.
145 changes: 145 additions & 0 deletions examples/dynamo/engine_caching_example.py
@@ -0,0 +1,145 @@
import time

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode

np.random.seed(0)
torch.manual_seed(0)
size = (100, 3, 224, 224)
inputs = [torch.rand(size).to("cuda")]

model = models.resnet18(pretrained=True).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
enabled_precisions = {torch.float}
debug = False
workspace_size = 20 << 30
min_block_size = 0
use_python_runtime = False
torch_executed_ops = {}


def dynamo_path():
############### warmup ###############
inputs = [torch.rand(size).to("cuda")]
t1 = time.time()
trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refitable=True,
ignore_engine_cache=True,
) # Output is a torch.fx.GraphModule
t2 = time.time()

############### compile for the first time ###############
inputs = [torch.rand(size).to("cuda")]
t3 = time.time()
trt_gm1 = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refitable=True,
ignore_engine_cache=False,
) # Output is a torch.fx.GraphModule
t4 = time.time()
# Check the output
outputs = trt_gm1(*inputs)
print("----------> 1st output:", outputs)

############### compile for the second time ###############
inputs = [torch.rand(size).to("cuda")]
t5 = time.time()
trt_gm2 = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refitable=True,
ignore_engine_cache=False,
) # Output is a torch.fx.GraphModule
t6 = time.time()
# Check the output
outputs = trt_gm2(*inputs)
print("----------> 2nd output:", outputs)

print("----------> warmup compilation time:", t2 - t1, "seconds")
print("----------> 1st compilation time:", t4 - t3, "seconds")
print("----------> 2nd compilation time:", t6 - t5, "seconds")


def compile_path():
inputs = [torch.rand(size).to("cuda")]
model = models.resnet18(pretrained=True).eval().to("cuda")
t1 = time.time()
model = torch.compile(
model,
backend="tensorrt",
options={
"use_python_runtime": use_python_runtime,
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"torch_executed_ops": torch_executed_ops,
"make_refitable": True,
"ignore_engine_cache": True,
},
)
t2 = time.time()
print("---------->", model(*inputs))

t3 = time.time()
model1 = torch.compile(
model,
backend="tensorrt",
options={
"use_python_runtime": use_python_runtime,
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"torch_executed_ops": torch_executed_ops,
"make_refitable": True,
"ignore_engine_cache": False,
},
)
t4 = time.time()
print("----------> 1st output:", model1(*inputs))

t5 = time.time()
model2 = torch.compile(
model,
backend="tensorrt",
options={
"use_python_runtime": use_python_runtime,
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"torch_executed_ops": torch_executed_ops,
"make_refitable": True,
"ignore_engine_cache": False,
},
)
t6 = time.time()
print("----------> 2nd output:", model2(*inputs))

print("----------> warmup compilation time:", t2 - t1, "seconds")
print("----------> 1st compilation time:", t4 - t3, "seconds")
print("----------> 2nd compilation time:", t6 - t5, "seconds")


if __name__ == "__main__":
dynamo_path()
compile_path()
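
When benchmarking this example, it can help to start from an empty cache so the "1st compilation" timing reflects a true cold start. A minimal sketch, assuming the default cache directory defined in _defaults.py below (adjust the path if engine_cache_dir is overridden):

import os
import shutil
import tempfile

# Default engine cache location (mirrors ENGINE_CACHE_DIR in torch_tensorrt.dynamo._defaults)
cache_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")

# Remove engines cached by earlier runs so the first timed compilation is a cold start
shutil.rmtree(cache_dir, ignore_errors=True)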
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -79,6 +79,9 @@ def compile(
dryrun: bool = _defaults.DRYRUN,
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
ignore_engine_cache: bool = _defaults.IGNORE_ENGINE_CACHE,
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -139,6 +142,9 @@ def compile(
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
ignore_engine_cache (bool): Whether to ignore the cached TRT engines and recompile the module
engine_cache_dir (str): Directory to store the cached TRT engines
engine_cache_size (int): Maximum hard-disk space (in bytes) to use for the engine cache
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -234,6 +240,9 @@ def compile(
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"ignore_engine_cache": ignore_engine_cache,
"engine_cache_dir": engine_cache_dir,
"engine_cache_size": engine_cache_size,
}

settings = CompilationSettings(**compilation_options)
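
For reference, a minimal sketch of how the new engine-cache arguments might be passed to the compile API (reusing exp_program and inputs from the example above; the values shown are simply the defaults from _defaults.py):

import os
import tempfile

import torch_tensorrt as torch_trt

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refitable=True,
    ignore_engine_cache=False,  # look up and reuse cached engines when available
    engine_cache_dir=os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache"),
    engine_cache_size=1 << 30,  # cap the on-disk cache at 1 GiB
)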
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -32,6 +32,9 @@
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
IGNORE_ENGINE_CACHE = False
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1 << 30


def default_device() -> Device:
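
For reference, 1 << 30 bytes is 1 GiB, and the default cache directory resolves to a torch_tensorrt_engine_cache folder under the system temp directory (e.g. /tmp on Linux); a quick sanity check:

import os
import tempfile

print(os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache"))  # default ENGINE_CACHE_DIR
print((1 << 30) == 2**30)  # True: default ENGINE_CACHE_SIZE is 1 GiB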
152 changes: 152 additions & 0 deletions py/torch_tensorrt/dynamo/_engine_caching.py
@@ -0,0 +1,152 @@
import ast
import copy
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, cast

import torch
from torch._inductor.codecache import FxGraphCachePickler
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR, ENGINE_CACHE_SIZE

_LOGGER: logging.Logger = logging.getLogger(__name__)


class BaseEngineCache(ABC):

@staticmethod
def get_hash(gm: torch.fx.GraphModule) -> str:
"""Get the hash value of the GraphModule
Args:
gm (torch.fx.GraphModule): GraphModule to hash
Returns:
str: hash value of the GraphModule
"""
# Zero out the parameters so that the hash depends only on the graph
# structure, not on the weight values
with maybe_disable_fake_tensor_mode():
    new_gm = copy.deepcopy(gm)
    for name, param in new_gm.named_parameters():
        param.data.zero_()

    hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))

return hash_val

@abstractmethod
def save(
self,
hash: str,
serialized_engine: bytes,
input_names: List[str],
output_names: List[str],
) -> None:
"""Save the serialized engine to hard disk
Args:
hash (str): hash value of the GraphModule
serialized_engine (bytes): serialized TRT engine
input_names (List[str]): input names of TRT engine
output_names (List[str]): output names of TRT engine
Returns:
None
"""
pass

@abstractmethod
def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
"""Load the serialized engine from hard disk
Args:
hash (str): hash value of the GraphModule
Returns:
Tuple[Optional[bytes], List[str], List[str]]: serialized TRT engine, input names of TRT engine, output names of TRT engine
"""
pass

@abstractmethod
def clear_cache(self, size: int) -> None:
"""Clear the cache to make sure at least `size` bytes are available
Args:
size (int): the number of bytes that need to be made available
"""
pass


class EngineCache(BaseEngineCache):

def __init__(
self,
engine_cache_size: int = ENGINE_CACHE_SIZE,
engine_cache_dir: str = ENGINE_CACHE_DIR,
) -> None:
self.total_engine_cache_size = engine_cache_size
self.available_engine_cache_size = engine_cache_size
self.engine_cache_dir = engine_cache_dir

def has_available_cache_size(self, serialized_engine: bytes) -> bool:
"""Check if the cache has available space for saving the serialized engine
Args:
serialized_engine (bytes): serialized TRT engine
Returns:
bool: whether the cache has available size for the serialized engine
"""
return len(serialized_engine) <= self.available_engine_cache_size

def clear_cache(self, size: int) -> None:
    # TODO: implement an eviction policy (e.g. LRU) to free at least `size` bytes;
    # clearing the cache is currently a no-op placeholder
    def LRU() -> None:
        pass

    pass

def save(
self,
hash: str,
serialized_engine: bytes,
input_names: List[str],
output_names: List[str],
) -> None:
serialized_engine_size = len(serialized_engine)
if serialized_engine_size > self.total_engine_cache_size:
    _LOGGER.warning(
        f"The serialized engine cannot be saved because the size of the engine {serialized_engine_size} is larger than the total cache size {self.total_engine_cache_size}."
    )
    return

if not self.has_available_cache_size(serialized_engine):
self.clear_cache(serialized_engine_size)

path = os.path.join(
self.engine_cache_dir, f"{hash}/engine_{input_names}_{output_names}.trt"
)
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
f.write(serialized_engine)
_LOGGER.info(f"A TRT engine was cached to {path}")

def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
directory = os.path.join(self.engine_cache_dir, hash)
if os.path.exists(directory):
engine_list = os.listdir(directory)
assert (
    len(engine_list) == 1
), f"Expected exactly one engine file under {directory}, but found: {engine_list}."
path = os.path.join(directory, engine_list[0])
# The filename encodes the input/output names as Python list literals:
# "engine_{input_names}_{output_names}.trt" (see save() above)
stem = engine_list[0][len("engine_") : -len(".trt")]
input_names_str, output_names_str = stem.split("]_[")
input_names = ast.literal_eval(input_names_str + "]")
output_names = ast.literal_eval("[" + output_names_str)
with open(path, "rb") as f:
serialized_engine = f.read()
return serialized_engine, input_names, output_names
else:
return None, [], []
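
BaseEngineCache is an abstract class, so other storage backends could implement the same save/load/clear_cache interface. A minimal illustrative sketch (not part of this commit, and not wired into compile()) of an in-memory variant:

from typing import Dict, List, Optional, Tuple

from torch_tensorrt.dynamo._engine_caching import BaseEngineCache


class InMemoryEngineCache(BaseEngineCache):
    """Illustrative cache that keeps serialized engines in a dict instead of on disk."""

    def __init__(self) -> None:
        self._store: Dict[str, Tuple[bytes, List[str], List[str]]] = {}

    def save(
        self,
        hash: str,
        serialized_engine: bytes,
        input_names: List[str],
        output_names: List[str],
    ) -> None:
        self._store[hash] = (serialized_engine, input_names, output_names)

    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
        if hash in self._store:
            return self._store[hash]
        return None, [], []

    def clear_cache(self, size: int) -> None:
        # For the in-memory example, simply drop everything
        self._store.clear()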
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -14,8 +14,11 @@
DRYRUN,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLED_PRECISIONS,
ENGINE_CACHE_DIR,
ENGINE_CACHE_SIZE,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
IGNORE_ENGINE_CACHE,
MAKE_REFITABLE,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
@@ -73,6 +76,9 @@ class CompilationSettings:
output to a file if a string path is specified
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
ignore_engine_cache (bool): Whether to ignore the cached TRT engines and recompile the module
engine_cache_dir (str): Directory to store the cached TRT engines
engine_cache_size (int): Maximum hard-disk space (in bytes) to use for the engine cache
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -104,3 +110,6 @@ class CompilationSettings:
dryrun: Union[bool, str] = DRYRUN
hardware_compatible: bool = HARDWARE_COMPATIBLE
timing_cache_path: str = TIMING_CACHE_PATH
ignore_engine_cache: bool = IGNORE_ENGINE_CACHE
engine_cache_dir: str = ENGINE_CACHE_DIR
engine_cache_size: int = ENGINE_CACHE_SIZE
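
For illustration, the new fields can also be set directly on the settings dataclass (in normal use they are populated from the corresponding compile() arguments):

import os
import tempfile

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(
    ignore_engine_cache=False,  # reuse cached engines when available
    engine_cache_dir=os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache"),
    engine_cache_size=1 << 30,  # 1 GiB cap on the engine cache
)
print(settings.engine_cache_dir)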
22 changes: 10 additions & 12 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -96,21 +96,19 @@ def _pretraced_backend(
),
)

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = post_lowering(gm, sample_inputs)
gm = post_lowering(gm, sample_inputs)

logger.debug("Lowered Input graph:\n " + str(gm.graph))
logger.debug("Lowered Input graph:\n " + str(gm.graph))

torchtrt_inputs = prepare_inputs(
torch_inputs, disable_memory_format_check=True
)
trt_compiled = compile_module(
gm,
torchtrt_inputs,
settings=settings,
)
return trt_compiled
torchtrt_inputs = prepare_inputs(torch_inputs, disable_memory_format_check=True)
trt_compiled = compile_module(
gm,
torchtrt_inputs,
settings=settings,
)
return trt_compiled
except (AssertionError, RuntimeError):
if not settings.pass_through_build_failures:
logger.warning(