From 748b4c6f9a698ee446f3de41e64149b16769486c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 20 Aug 2024 22:12:01 -0700 Subject: [PATCH] refactor --- .../dynamo/engine_caching_bert_example.py | 13 +- examples/dynamo/engine_caching_example.py | 89 +++---- py/torch_tensorrt/dynamo/_compiler.py | 47 +--- py/torch_tensorrt/dynamo/_defaults.py | 9 +- py/torch_tensorrt/dynamo/_engine_caching.py | 245 ++++++++---------- py/torch_tensorrt/dynamo/_settings.py | 24 +- .../dynamo/conversion/_TRTInterpreter.py | 62 +++-- py/torch_tensorrt/dynamo/utils.py | 16 ++ 8 files changed, 227 insertions(+), 278 deletions(-) diff --git a/examples/dynamo/engine_caching_bert_example.py b/examples/dynamo/engine_caching_bert_example.py index 2f133f5e8f..43cfc5f15a 100644 --- a/examples/dynamo/engine_caching_bert_example.py +++ b/examples/dynamo/engine_caching_bert_example.py @@ -29,11 +29,11 @@ def compile_bert(iterations=3): torch._dynamo.reset() if i == 0: - save_engine_cache = False - load_engine_cache = False + cache_built_engines = False + reuse_cached_engines = False else: - save_engine_cache = True - load_engine_cache = True + cache_built_engines = True + reuse_cached_engines = True start.record() compilation_kwargs = { @@ -43,8 +43,9 @@ def compile_bert(iterations=3): "debug": False, "min_block_size": 1, "make_refitable": True, - "save_engine_cache": save_engine_cache, - "load_engine_cache": load_engine_cache, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache", "engine_cache_size": 1 << 30, # 1GB } optimized_model = torch.compile( diff --git a/examples/dynamo/engine_caching_example.py b/examples/dynamo/engine_caching_example.py index 80cf696466..89912e74b0 100644 --- a/examples/dynamo/engine_caching_example.py +++ b/examples/dynamo/engine_caching_example.py @@ -1,7 +1,5 @@ -import ast -import logging import os -from typing import List, Optional, Tuple +from typing import Optional import numpy as np import torch @@ -10,9 +8,6 @@ from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH from torch_tensorrt.dynamo._engine_caching import BaseEngineCache -_LOGGER: logging.Logger = logging.getLogger(__name__) - - np.random.seed(0) torch.manual_seed(0) size = (100, 3, 224, 224) @@ -49,11 +44,11 @@ def dynamo_path(iterations=3): inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")] remove_timing_cache() # remove timing cache for engine caching messurement if i == 0: - save_engine_cache = False - load_engine_cache = False + cache_built_engines = False + reuse_cached_engines = False else: - save_engine_cache = True - load_engine_cache = True + cache_built_engines = True + reuse_cached_engines = True start.record() trt_gm = torch_trt.dynamo.compile( @@ -64,8 +59,8 @@ def dynamo_path(iterations=3): debug=debug, min_block_size=min_block_size, make_refitable=True, - save_engine_cache=save_engine_cache, - load_engine_cache=load_engine_cache, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, engine_cache_size=1 << 30, # 1GB ) end.record() @@ -79,60 +74,36 @@ def dynamo_path(iterations=3): class MyEngineCache(BaseEngineCache): def __init__( self, - engine_cache_size: int, engine_cache_dir: str, ) -> None: - self.total_engine_cache_size = engine_cache_size - self.available_engine_cache_size = engine_cache_size self.engine_cache_dir = engine_cache_dir def save( self, hash: str, - serialized_engine: bytes, - input_names: List[str], - output_names: List[str], - ) -> bool: + blob: bytes, + prefix: str = "blob", + ): path = os.path.join( self.engine_cache_dir, - f"{hash}/engine--{input_names}--{output_names}.trt", + f"{prefix}_{hash}.bin", ) - try: - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "wb") as f: - f.write(serialized_engine) - except Exception as e: - _LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}") - return False - - _LOGGER.info(f"A TRT engine was cached to {path}") - serialized_engine_size = int(serialized_engine.nbytes) - self.available_engine_cache_size -= serialized_engine_size - return True - - def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]: - directory = os.path.join(self.engine_cache_dir, hash) - if os.path.exists(directory): - engine_list = os.listdir(directory) - assert ( - len(engine_list) == 1 - ), f"There are more than one engine {engine_list} under {directory}." - path = os.path.join(directory, engine_list[0]) - input_names_str, output_names_str = ( - engine_list[0].split(".trt")[0].split("--")[1:] - ) - input_names = ast.literal_eval(input_names_str) - output_names = ast.literal_eval(output_names_str) + os.makedirs(path, exist_ok=True) + with open(path, "wb") as f: + f.write(blob) + + def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]: + path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin") + if os.path.exists(path): with open(path, "rb") as f: - serialized_engine = f.read() - return serialized_engine, input_names, output_names - else: - return None, [], [] + blob = f.read() + return blob + return None def compile_path(iterations=3): times = [] - engine_cache = MyEngineCache(200 * (1 << 20), "/tmp/your_dir") + engine_cache = MyEngineCache("/tmp/your_dir") start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -147,11 +118,11 @@ def compile_path(iterations=3): torch._dynamo.reset() if i == 0: - save_engine_cache = False - load_engine_cache = False + cache_built_engines = False + reuse_cached_engines = False else: - save_engine_cache = True - load_engine_cache = True + cache_built_engines = True + reuse_cached_engines = True start.record() compiled_model = torch.compile( @@ -163,9 +134,9 @@ def compile_path(iterations=3): "debug": debug, "min_block_size": min_block_size, "make_refitable": True, - "save_engine_cache": save_engine_cache, - "load_engine_cache": load_engine_cache, - "engine_cache_instance": engine_cache, # use custom engine cache + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "custom_engine_cache": engine_cache, # use custom engine cache }, ) compiled_model(*inputs) # trigger the compilation @@ -178,4 +149,4 @@ def compile_path(iterations=3): if __name__ == "__main__": dynamo_path() - compile_path() + # compile_path() diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 5f8408d0b9..5ddd615f90 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -18,7 +18,7 @@ dryrun_stats_display, parse_non_trt_nodes, ) -from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, EngineCache +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, DiskEngineCache from torch_tensorrt.dynamo.conversion import ( CompilationSettings, UnsupportedOperatorException, @@ -84,11 +84,11 @@ def compile( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, - save_engine_cache: bool = _defaults.SAVE_ENGINE_CACHE, - load_engine_cache: bool = _defaults.LOAD_ENGINE_CACHE, + cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, + reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR, engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE, - engine_cache_instance: Optional[BaseEngineCache] = None, + custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -154,11 +154,11 @@ def compile( hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. - save_engine_cache (bool): Whether to save the compiled TRT engines to hard disk - load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk + cache_built_engines (bool): Whether to save the compiled TRT engines to storage + reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage engine_cache_dir (str): Directory to store the cached TRT engines engine_cache_size (int): Maximum hard-disk space to use for the engine cache - engine_cache_instance (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache + custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -235,10 +235,9 @@ def compile( gm = post_lowering(gm) logger.debug("Lowered Input graph: " + str(gm.graph)) - if engine_cache_instance is None: - engine_cache_instance = EngineCacheInstanceCreator.get_creator( - engine_cache_size, engine_cache_dir - ).engine_cache_instance + if cache_built_engines or reuse_cached_engines: + if custom_engine_cache is None: + custom_engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size) compilation_options = { "enabled_precisions": ( @@ -273,11 +272,9 @@ def compile( "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, "lazy_engine_init": lazy_engine_init, - "save_engine_cache": save_engine_cache, - "load_engine_cache": load_engine_cache, - "engine_cache_dir": engine_cache_dir, - "engine_cache_size": engine_cache_size, - "engine_cache_instance": engine_cache_instance, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "custom_engine_cache": custom_engine_cache, } settings = CompilationSettings(**compilation_options) @@ -724,21 +721,3 @@ def convert_exported_program_to_serialized_trt_engine( serialized_engine: bytes = interpreter_result.serialized_engine return serialized_engine - - -class EngineCacheInstanceCreator: - engine_cache_creator = None - - def __init__(self, engine_cache_size: int, engine_cache_dir: str) -> None: - self.engine_cache_instance = EngineCache( - engine_cache_size=engine_cache_size, - engine_cache_dir=engine_cache_dir, - ) - - @classmethod - def get_creator( - cls, engine_cache_size: int, engine_cache_dir: str - ) -> EngineCacheInstanceCreator: - if cls.engine_cache_creator is None: - cls.engine_cache_creator = cls(engine_cache_size, engine_cache_dir) - return cls.engine_cache_creator diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index e90f3f8c2a..83e85cb3c7 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -4,7 +4,6 @@ import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype -from torch_tensorrt.dynamo._engine_caching import EngineCache ENABLED_PRECISIONS = {dtype.f32} DEBUG = False @@ -36,13 +35,11 @@ tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) LAZY_ENGINE_INIT = False -SAVE_ENGINE_CACHE = True -LOAD_ENGINE_CACHE = True +CACHE_BUILT_ENGINES = True +REUSE_CACHED_ENGINES = True ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache") ENGINE_CACHE_SIZE = 1073741824 -ENGINE_CACHE_INSTANCE = EngineCache( - engine_cache_size=ENGINE_CACHE_SIZE, engine_cache_dir=ENGINE_CACHE_DIR -) +CUSTOM_ENGINE_CACHE = None def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_engine_caching.py b/py/torch_tensorrt/dynamo/_engine_caching.py index f9b6f075eb..01220233ea 100644 --- a/py/torch_tensorrt/dynamo/_engine_caching.py +++ b/py/torch_tensorrt/dynamo/_engine_caching.py @@ -3,7 +3,6 @@ import os import pickle import shutil -import sys from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, cast @@ -44,78 +43,126 @@ def get_hash(gm: torch.fx.GraphModule) -> str: return hash_val - @abstractmethod - def save( - self, - hash: str, + @staticmethod + def pack( serialized_engine: bytes, input_names: List[str], output_names: List[str], - weight_name_map: Optional[Dict[str, Any]] = None, - ) -> bool: - """Save the serialized engine to hard disk + weight_name_map: Optional[Dict[str, Any]], + ) -> bytes: + """Pack serialized engine, input names, output names, and weight map into a single blob Args: - hash (str): hash value of the GraphModule serialized_engine (bytes): serialized TRT engine input_names (List[str]): input names of TRT engine output_names (List[str]): output names of TRT engine weight_name_map (Optional[Dict[str, Any]]): weight name map for refitting Returns: - bool: whether the serialized engine is saved successfully + bytes: packed blob + """ + return pickle.dumps( + { + "serialized_engine": bytes(serialized_engine), + "input_names": input_names, + "output_names": output_names, + "weight_name_map": weight_name_map, + } + ) + + @staticmethod + def unpack( + packed_obj: bytes, + ) -> Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: + """Unpack packed blob into serialized engine, input names, output names, and weight map + + Args: + packed_obj (bytes): packed blob + + Returns: + Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map + """ + unpacked = pickle.loads(packed_obj) + return ( + unpacked["serialized_engine"], + unpacked["input_names"], + unpacked["output_names"], + unpacked["weight_name_map"], + ) + + @abstractmethod + def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None: + """Store blob in cache + + Args: + hash (str): hash value of the GraphModule + blob (bytes): packed blob """ pass @abstractmethod - def load( - self, hash: str - ) -> Tuple[Optional[bytes], List[str], List[str], Optional[Dict[str, Any]]]: - """Load the serialized engine from hard disk + def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]: + """Load blob from storage Args: hash (str): hash value of the GraphModule Returns: - Sequence[Optional[bytes], List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map + Optional[bytes]: blob or None if doesn't hit """ pass -class EngineCache(BaseEngineCache): +class DiskEngineCache(BaseEngineCache): + dir2hash2size_map: Dict[str, Dict[str, int]] = ( + {} + ) # dir2hash2size_map["engine_cache_dir"]["hash"] = size def __init__( self, - engine_cache_size: int, engine_cache_dir: str, + engine_cache_size: int, ) -> None: - self.total_engine_cache_size = engine_cache_size - self.available_engine_cache_size = engine_cache_size + + def get_dir_size(path: str) -> int: + total = 0 + with os.scandir(path) as it: + for entry in it: + if entry.is_file(): + total += entry.stat().st_size + elif entry.is_dir(): + total += get_dir_size(entry.path) + return total + + if not os.path.exists(engine_cache_dir): + os.makedirs(engine_cache_dir, exist_ok=True) self.engine_cache_dir = engine_cache_dir - self.hash2size_map: Dict[str, int] = {} + self.total_engine_cache_size = engine_cache_size + self.available_engine_cache_size = engine_cache_size - get_dir_size( + engine_cache_dir + ) + if engine_cache_dir not in DiskEngineCache.dir2hash2size_map: + DiskEngineCache.dir2hash2size_map[engine_cache_dir] = {} def has_available_cache_size(self, needed_size: int) -> bool: - """Check if the cache has available space for saving the serialized engine + """Check if the cache has available space for saving object Args: - needed_size (int): needed size for erialized TRT engine and/or weight_name_map + needed_size (int): needed size for saving object Returns: - bool: whether the cache has available size for the serialized engine + bool: whether the cache has available size for saving object """ return needed_size <= self.available_engine_cache_size - def clear_cache(self, needed_min_size: int) -> bool: + def clear_cache(self, needed_min_size: int) -> None: """Clear the cache to make sure at least `needed_min_size` bytes are available, if possible Args: needed_min_size (int): the minimum needed size - - Returns: - bool: whether the cache is cleared successfully """ - def LRU() -> bool: + def LRU() -> None: """Clear the Least Recently Used engine in the cache""" # Get the list of engine directories engines_hash_values = os.listdir(self.engine_cache_dir) @@ -132,8 +179,10 @@ def LRU() -> bool: # Remove the entire directory shutil.rmtree(engine_path) # Update the available cache size - self.available_engine_cache_size += self.hash2size_map.pop( - engine_hash, 0 + self.available_engine_cache_size += ( + DiskEngineCache.dir2hash2size_map[self.engine_cache_dir].pop( + engine_hash, 0 + ) ) _LOGGER.info( f"Removed the engine cache at {engine_path}, available cache size: {self.available_engine_cache_size} bytes." @@ -142,127 +191,61 @@ def LRU() -> bool: _LOGGER.warning( f"Failed to clear the engine cache at {engine_path}: {e}" ) - return False - return True - if not os.path.exists(self.engine_cache_dir): - return False - - _LOGGER.info( - f"Total cache size: {self.total_engine_cache_size} bytes; available cache size: {self.available_engine_cache_size} bytes. Clearing the cache to make sure at least {needed_min_size} bytes are available." - ) - return LRU() + if needed_min_size > self.total_engine_cache_size: + _LOGGER.warning( + f"The needed minimum size {needed_min_size} is larger than the total cache size {self.total_engine_cache_size}. Nothing will be cleared." + ) + else: + LRU() def save( self, hash: str, - serialized_engine: bytes, - input_names: List[str], - output_names: List[str], - weight_name_map: Optional[Dict[str, Any]] = None, - ) -> bool: - serialized_engine_size = int(serialized_engine.nbytes) - if weight_name_map is not None: - serialized_engine_size += sum( - sys.getsizeof(v) for v in weight_name_map.values() - ) - if serialized_engine_size > self.total_engine_cache_size: + blob: bytes, + ) -> None: + blob_size = len(blob) + if blob_size > self.total_engine_cache_size: _LOGGER.warning( - f"The serialized engine cannot be saved because the size of the engine {serialized_engine_size} is larger than the total cache size {self.total_engine_cache_size}." + f"The serialized engine cannot be saved because the size {blob_size} is larger than the total cache size {self.total_engine_cache_size}." ) - return False + return - # Check if there is enough available cache size for the serialized engine and/or weight_name_map - if not self.has_available_cache_size(serialized_engine_size): - self.clear_cache(serialized_engine_size) + if not self.has_available_cache_size(blob_size): + self.clear_cache(blob_size) - # Save the serialized engine to the cache directory - if self.has_available_cache_size(serialized_engine_size): - self.hash2size_map[hash] = serialized_engine_size - self.available_engine_cache_size -= serialized_engine_size + if self.has_available_cache_size(blob_size): + DiskEngineCache.dir2hash2size_map[self.engine_cache_dir][hash] = blob_size + self.available_engine_cache_size -= blob_size directory = os.path.join(self.engine_cache_dir, hash) + if not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) - engine_path = os.path.join( - directory, - "engine.trt", - ) - io_names_path = os.path.join( + blob_path = os.path.join( directory, - "io_names.pkl", + "blob.bin", ) try: - os.makedirs(os.path.dirname(engine_path), exist_ok=True) - with open(engine_path, "wb") as f: - f.write(serialized_engine) - os.makedirs(os.path.dirname(io_names_path), exist_ok=True) - with open(io_names_path, "wb") as f: - pickle.dump( - {"input_names": input_names, "output_names": output_names}, f - ) - _LOGGER.info(f"The TRT engine was saved to {engine_path}") + with open(blob_path, "wb") as f: + f.write(blob) + _LOGGER.info(f"The blob was saved to {blob_path}") except Exception as e: - del self.hash2size_map[hash] - self.available_engine_cache_size += serialized_engine_size + del DiskEngineCache.dir2hash2size_map[self.engine_cache_dir][hash] + self.available_engine_cache_size += blob_size shutil.rmtree(directory) - _LOGGER.warning(f"Failed to save the TRT engine to {engine_path}: {e}") - return False - - if weight_name_map is not None: - weight_name_map_path = os.path.join( - directory, - "weight_name_map.pkl", - ) - try: - os.makedirs(os.path.dirname(weight_name_map_path), exist_ok=True) - with open(weight_name_map_path, "wb") as f: - pickle.dump(weight_name_map, f) - _LOGGER.info( - f"The weight_name_map was saved to {weight_name_map_path}" - ) - except Exception as e: - del self.hash2size_map[hash] - self.available_engine_cache_size += serialized_engine_size - shutil.rmtree(directory) - _LOGGER.warning( - f"Failed to save the weight_name_map to {weight_name_map_path}: {e}" - ) - return False - - return True + _LOGGER.warning(f"Failed to save the blob to {blob_path}: {e}") else: _LOGGER.warning( - f"The serialized engine {serialized_engine_size} is still larger than the available cache size {self.available_engine_cache_size}." + f"The size {blob_size} is still larger than the available cache size {self.available_engine_cache_size}." ) - return False - def load( - self, hash: str - ) -> Tuple[Optional[bytes], List[str], List[str], Optional[Dict[str, Any]]]: + def load(self, hash: str) -> Optional[bytes]: directory = os.path.join(self.engine_cache_dir, hash) if os.path.exists(directory): - # load engine - serialized_engine = None - engine_path = os.path.join(directory, "engine.trt") - if os.path.exists(engine_path): - with open(engine_path, "rb") as f: - serialized_engine = f.read() - - input_names = [] - output_names = [] - io_names_path = os.path.join(directory, "io_names.pkl") - if os.path.exists(io_names_path): - with open(io_names_path, "rb") as f: - io_names = pickle.load(f) - input_names = io_names["input_names"] - output_names = io_names["output_names"] - - # load weight_name_map - weight_name_map = None - weight_name_map_path = os.path.join(directory, "weight_name_map.pkl") - if os.path.exists(weight_name_map_path): - with open(weight_name_map_path, "rb") as f: - weight_name_map = pickle.load(f) - return serialized_engine, input_names, output_names, weight_name_map - else: - return None, [], [], {} + blob_path = os.path.join(directory, "blob.bin") + if os.path.exists(blob_path): + with open(blob_path, "rb") as f: + blob = f.read() + return blob + return None diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 90c17d03c3..0327727c9f 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -6,6 +6,8 @@ from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, + CACHE_BUILT_ENGINES, + CUSTOM_ENGINE_CACHE, DEBUG, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -14,13 +16,9 @@ DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENABLED_PRECISIONS, - ENGINE_CACHE_DIR, - ENGINE_CACHE_INSTANCE, - ENGINE_CACHE_SIZE, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, LAZY_ENGINE_INIT, - LOAD_ENGINE_CACHE, MAKE_REFITABLE, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, @@ -28,7 +26,7 @@ OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, REQUIRE_FULL_COMPILATION, - SAVE_ENGINE_CACHE, + REUSE_CACHED_ENGINES, SPARSE_WEIGHTS, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, @@ -80,11 +78,9 @@ class CompilationSettings: output to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation - save_engine_cache (bool): Whether to save the compiled TRT engines to hard disk - load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk - engine_cache_dir (str): Directory to store the cached TRT engines - engine_cache_size (int): Maximum hard-disk space to use for the engine cache - engine_cache_instance (BaseEngineCache): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache + cache_built_engines (bool): Whether to save the compiled TRT engines to storage + reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage + custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -117,8 +113,6 @@ class CompilationSettings: hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH lazy_engine_init: bool = LAZY_ENGINE_INIT - save_engine_cache: bool = SAVE_ENGINE_CACHE - load_engine_cache: bool = LOAD_ENGINE_CACHE - engine_cache_dir: str = ENGINE_CACHE_DIR - engine_cache_size: int = ENGINE_CACHE_SIZE - engine_cache_instance: BaseEngineCache = ENGINE_CACHE_INSTANCE + cache_built_engines: bool = CACHE_BUILT_ENGINES + reuse_cached_engines: bool = REUSE_CACHED_ENGINES + custom_engine_cache: Optional[BaseEngineCache] = CUSTOM_ENGINE_CACHE diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 86a549f2f0..91cad20f0d 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -474,30 +474,35 @@ def run( TRTInterpreterResult """ if ( - self.compilation_settings.save_engine_cache - or self.compilation_settings.load_engine_cache - ): - engine_cache = self.compilation_settings.engine_cache_instance - hash_val = engine_cache.get_hash(self.module) - - if self.compilation_settings.load_engine_cache: - # query the cached TRT engine - serialized_engine, input_names, output_names, weight_name_map = ( - engine_cache.load(hash_val) - ) - if serialized_engine is not None: - self._input_names = input_names - self._output_names = output_names - self.weight_name_map = weight_name_map - _LOGGER.info( - "Hit the cached TRT engine. It is loaded for skipping recompilation." - ) - return TRTInterpreterResult( - serialized_engine, - self._input_names, - self._output_names, - self.weight_name_map, - ) + self.compilation_settings.custom_engine_cache is not None + ): # custom_engine_cache could be None if this function is called from convert_exported_program_to_serialized_trt_engine etc. + if ( + self.compilation_settings.cache_built_engines + or self.compilation_settings.reuse_cached_engines + ): + engine_cache = self.compilation_settings.custom_engine_cache + hash_val = engine_cache.get_hash(self.module) + + if self.compilation_settings.reuse_cached_engines: + # query the cached TRT engine + blob = engine_cache.load(hash_val) + if blob is not None: # hit the cache + serialized_engine, input_names, output_names, weight_name_map = ( + engine_cache.unpack(blob) + ) + self._input_names = input_names + self._output_names = output_names + self.weight_name_map = weight_name_map + _LOGGER.info( + "Hit the cached TRT engine. It is loaded and skip recompilation." + ) + # TODO: refit the engine here or outside (within convert_module)? + return TRTInterpreterResult( + serialized_engine, + self._input_names, + self._output_names, + self.weight_name_map, + ) self._construct_trt_network_def() @@ -528,14 +533,17 @@ def run( self._save_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - if self.compilation_settings.save_engine_cache: - engine_cache.save( - hash_val, + if ( + self.compilation_settings.custom_engine_cache is not None + and self.compilation_settings.cache_built_engines + ): + blob = engine_cache.pack( serialized_engine, self._input_names, self._output_names, self.weight_name_map, ) + engine_cache.save(hash_val, blob) with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 1d7785717b..22ad4ab7bc 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -356,6 +356,22 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: ) settings.require_full_compilation = False + # If cache_built_engines and reuse_cached_engines are True but custom_engine_cache is not provided, + # then create a default disk engine cache + if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"): + if settings.custom_engine_cache is None: + from torch_tensorrt.dynamo._engine_caching import DiskEngineCache + + engine_cache_dir = kwargs.get( + "engine_cache_dir", _defaults.ENGINE_CACHE_DIR + ) + engine_cache_size = kwargs.get( + "engine_cache_size", _defaults.ENGINE_CACHE_SIZE + ) + settings.custom_engine_cache = DiskEngineCache( + engine_cache_dir, engine_cache_size + ) + logger.info("Compilation Settings: %s\n", settings) return settings