Skip to content

Commit

Permalink
fix issues from comments, add more unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Aug 29, 2024
1 parent 24f5c2c commit 9723376
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 41 deletions.
24 changes: 16 additions & 8 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def compile(
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
engine_cache_dir: Optional[str] = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: Optional[int] = _defaults.ENGINE_CACHE_SIZE,
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
**kwargs: Any,
) -> torch.fx.GraphModule:
Expand Down Expand Up @@ -156,8 +156,8 @@ def compile(
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
engine_cache_dir (str): Directory to store the cached TRT engines
engine_cache_size (int): Maximum hard-disk space to use for the engine cache
engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
**kwargs: Any,
Returns:
Expand Down Expand Up @@ -235,12 +235,16 @@ def compile(
gm = post_lowering(gm)
logger.debug("Lowered Input graph: " + str(gm.graph))

engine_cache = None
if cache_built_engines or reuse_cached_engines:
assert (
make_refitable
), "Engine caching requires make_refitable to be set to True"
if custom_engine_cache is None:
custom_engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)
engine_cache = (
custom_engine_cache
if custom_engine_cache is not None
else DiskEngineCache(engine_cache_dir, engine_cache_size)
)

compilation_options = {
"enabled_precisions": (
Expand Down Expand Up @@ -277,12 +281,13 @@ def compile(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": custom_engine_cache,
}

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings)
trt_gm = compile_module(
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
)
return trt_gm


Expand All @@ -291,6 +296,7 @@ def compile_module(
sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
settings: CompilationSettings = CompilationSettings(),
engine_cache: Optional[BaseEngineCache] = None,
) -> torch.fx.GraphModule:
"""Compile a traced FX module
Expand All @@ -301,6 +307,7 @@ def compile_module(
arg_inputs: Inputs to the module
kwarg_inputs: kwargs to the module
settings: Compilation settings
engine_cache: Engine cache instance to store/load compiled engines
Returns:
Compiled FX GraphModule
"""
Expand Down Expand Up @@ -480,6 +487,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
submodule_inputs,
settings=settings,
name=name,
engine_cache=engine_cache,
)

trt_modules[name] = trt_module
Expand Down
4 changes: 0 additions & 4 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
CUSTOM_ENGINE_CACHE,
DEBUG,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
Expand Down Expand Up @@ -36,7 +35,6 @@
WORKSPACE_SIZE,
default_device,
)
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache


@dataclass
Expand Down Expand Up @@ -80,7 +78,6 @@ class CompilationSettings:
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
Expand Down Expand Up @@ -115,4 +112,3 @@ class CompilationSettings:
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
custom_engine_cache: Optional[BaseEngineCache] = CUSTOM_ENGINE_CACHE
7 changes: 5 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,23 @@ def torch_tensorrt_backend(
def aot_torch_tensorrt_aten_backend(
gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
settings = parse_dynamo_kwargs(kwargs)
return _pretraced_backend(gm, sample_inputs, settings)
settings, engine_cache = parse_dynamo_kwargs(kwargs)
return _pretraced_backend(gm, sample_inputs, settings, engine_cache)


def _pretraced_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[Any],
settings: CompilationSettings = CompilationSettings(),
engine_cache: Any = None,
) -> torch.fx.GraphModule | Callable[..., Any]:
"""Helper function to manage translation of traced FX module to TRT engines
Args:
module: FX GraphModule to convert
inputs: Inputs to the module
settings: Compilation settings
engine_cache: Engine cache instance
Returns:
Compiled FX GraphModule
"""
Expand Down Expand Up @@ -109,6 +111,7 @@ def _pretraced_backend(
gm,
torchtrt_inputs,
settings=settings,
engine_cache=engine_cache,
)
return trt_compiled
except (AssertionError, RuntimeError):
Expand Down
25 changes: 15 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
output_dtypes: Optional[Sequence[dtype]] = None,
compilation_settings: CompilationSettings = CompilationSettings(),
engine_cache: Optional[BaseEngineCache] = None,
):
super().__init__(module)

Expand Down Expand Up @@ -126,6 +128,9 @@ def __init__(
self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
self.weight_name_map: Optional[dict[str, Any]] = None

# Engine cache for storing and reusing TRT engines
self.engine_cache = engine_cache

def validate_conversion(self) -> Set[str]:
missing_converters: Set[str] = set()

Expand Down Expand Up @@ -521,22 +526,22 @@ def run(
Return:
TRTInterpreterResult
"""
if (
self.compilation_settings.custom_engine_cache is not None
): # custom_engine_cache could be None if this function is called from convert_exported_program_to_serialized_trt_engine etc.
# self.engine_cache could be None if:
# 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
# 2) both cache_built_engines and reuse_cached_engines are False
if self.engine_cache is not None:
if (
self.compilation_settings.cache_built_engines
or self.compilation_settings.reuse_cached_engines
):
engine_cache = self.compilation_settings.custom_engine_cache
hash_val = engine_cache.get_hash(self.module)
hash_val = self.engine_cache.get_hash(self.module)

if self.compilation_settings.reuse_cached_engines:
# query the cached TRT engine
blob = engine_cache.load(hash_val)
blob = self.engine_cache.load(hash_val)
if blob is not None: # hit the cache
serialized_engine, input_names, output_names, weight_name_map = (
engine_cache.unpack(blob)
self.engine_cache.unpack(blob)
)
self._input_names = input_names
self._output_names = output_names
Expand Down Expand Up @@ -605,16 +610,16 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)
if (
self.compilation_settings.custom_engine_cache is not None
self.engine_cache is not None
and self.compilation_settings.cache_built_engines
):
blob = engine_cache.pack(
blob = self.engine_cache.pack(
serialized_engine,
self._input_names,
self._output_names,
self.weight_name_map,
)
engine_cache.save(hash_val, blob)
self.engine_cache.save(hash_val, blob)

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
Expand Down
10 changes: 9 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
TRTInterpreter,
Expand Down Expand Up @@ -70,6 +71,7 @@ def interpret_module_to_result(
settings: CompilationSettings = CompilationSettings(),
arg_inputs: Optional[Sequence[Input]] = None,
kwarg_inputs: Optional[dict[str, Any]] = None,
engine_cache: Optional[BaseEngineCache] = None,
) -> TRTInterpreterResult:
"""Interpret an FX module to a TRTInterpreterResult
Args:
Expand All @@ -79,6 +81,7 @@ def interpret_module_to_result(
arg_inputs: Sequence of Tensors representing inputs to the module.
kwarg_inputs: A dictionary of Tensors representing inputs to the module.
settings: Compilation settings
engine_cache: Engine cache instance
Returns:
TRTInterpreterResult
"""
Expand All @@ -105,6 +108,7 @@ def interpret_module_to_result(
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
output_dtypes=output_dtypes,
compilation_settings=settings,
engine_cache=engine_cache,
)
interpreter_result = interpreter.run()
return interpreter_result
Expand All @@ -115,17 +119,21 @@ def convert_module(
inputs: Sequence[Input],
settings: CompilationSettings = CompilationSettings(),
name: str = "",
engine_cache: Optional[BaseEngineCache] = None,
) -> PythonTorchTensorRTModule | TorchTensorRTModule:
"""Convert an FX module to a TRT module
Args:
module: FX GraphModule to convert
inputs: Sequence of Tensors representing inputs to the module
settings: Compilation settings
name: TRT engine name
engine_cache: Engine cache instance
Returns:
PythonTorchTensorRTModule or TorchTensorRTModule
"""
interpreter_result = interpret_module_to_result(module, inputs, settings)
interpreter_result = interpret_module_to_result(
module, inputs, settings, engine_cache=engine_cache
)

rt_cls = PythonTorchTensorRTModule

Expand Down
19 changes: 12 additions & 7 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from dataclasses import fields, replace
from enum import Enum
from typing import Any, Callable, Dict, Optional, Sequence, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
Expand All @@ -12,6 +12,7 @@
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings

from packaging import version
Expand Down Expand Up @@ -301,7 +302,9 @@ def to_torch_tensorrt_device(
return Device._from(device)


def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
def parse_dynamo_kwargs(
kwargs: Any,
) -> Tuple[CompilationSettings, Optional[BaseEngineCache]]:
"""Parses the kwargs field of a Dynamo backend
Args:
Expand Down Expand Up @@ -360,11 +363,15 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:

# If cache_built_engines and reuse_cached_engines are True but custom_engine_cache is not provided,
# then create a default disk engine cache
engine_cache = None
if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"):
assert kwargs.get(
"make_refitable"
), "Engine caching requires make_refitable to be set to True"
if settings.custom_engine_cache is None:

if kwargs.get("custom_engine_cache") is not None:
engine_cache = kwargs.get("custom_engine_cache")
else:
from torch_tensorrt.dynamo._engine_caching import DiskEngineCache

engine_cache_dir = kwargs.get(
Expand All @@ -373,13 +380,11 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
engine_cache_size = kwargs.get(
"engine_cache_size", _defaults.ENGINE_CACHE_SIZE
)
settings.custom_engine_cache = DiskEngineCache(
engine_cache_dir, engine_cache_size
)
engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)

logger.info("Compilation Settings: %s\n", settings)

return settings
return settings, engine_cache


def req_torch_version(min_torch_version: str = "2.dev") -> Callable[..., Any]:
Expand Down
Loading

0 comments on commit 9723376

Please sign in to comment.