Showing 7 changed files with 365 additions and 12 deletions.
@@ -0,0 +1,145 @@
import time

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(0)
torch.manual_seed(0)
size = (100, 3, 224, 224)
inputs = [torch.rand(size).to("cuda")]

model = models.resnet18(pretrained=True).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
enabled_precisions = {torch.float}
debug = False
workspace_size = 20 << 30
min_block_size = 0
use_python_runtime = False
torch_executed_ops = set()  # an empty set of ops; `{}` would create an empty dict


def dynamo_path():
    ############### warmup ###############
    inputs = [torch.rand(size).to("cuda")]
    t1 = time.time()
    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        debug=debug,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        make_refitable=True,
        ignore_engine_cache=True,  # bypass the cache so this run always rebuilds
    )  # Output is a torch.fx.GraphModule
    t2 = time.time()

    ############### compile for the first time ###############
    inputs = [torch.rand(size).to("cuda")]
    t3 = time.time()
    trt_gm1 = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        debug=debug,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        make_refitable=True,
        ignore_engine_cache=False,  # cache miss expected: builds and saves the engine
    )  # Output is a torch.fx.GraphModule
    t4 = time.time()
    # Check the output
    outputs = trt_gm1(*inputs)
    print("----------> 1st output:", outputs)

    ############### compile for the second time ###############
    inputs = [torch.rand(size).to("cuda")]
    t5 = time.time()
    trt_gm2 = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        debug=debug,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        make_refitable=True,
        ignore_engine_cache=False,  # cache hit expected: loads the saved engine
    )  # Output is a torch.fx.GraphModule
    t6 = time.time()
    # Check the output
    outputs = trt_gm2(*inputs)
    print("----------> 2nd output:", outputs)

    print("----------> warmup compilation time:", t2 - t1, "seconds")
    print("----------> 1st compilation time:", t4 - t3, "seconds")
    print("----------> 2nd compilation time:", t6 - t5, "seconds")


def compile_path():
    inputs = [torch.rand(size).to("cuda")]
    model = models.resnet18(pretrained=True).eval().to("cuda")
    options = {
        "use_python_runtime": use_python_runtime,
        "enabled_precisions": enabled_precisions,
        "debug": debug,
        "min_block_size": min_block_size,
        "torch_executed_ops": torch_executed_ops,
        "make_refitable": True,
    }

    # torch.compile is lazy: the TRT engine is only built on the first call,
    # so each timed region includes one forward pass to trigger compilation.
    # Each variant wraps the original eager `model` rather than re-wrapping an
    # already-compiled module.
    t1 = time.time()
    warmup_model = torch.compile(
        model, backend="tensorrt", options={**options, "ignore_engine_cache": True}
    )
    print("---------->", warmup_model(*inputs))
    t2 = time.time()

    t3 = time.time()
    model1 = torch.compile(
        model, backend="tensorrt", options={**options, "ignore_engine_cache": False}
    )
    print("----------> 1st output:", model1(*inputs))
    t4 = time.time()

    t5 = time.time()
    model2 = torch.compile(
        model, backend="tensorrt", options={**options, "ignore_engine_cache": False}
    )
    print("----------> 2nd output:", model2(*inputs))
    t6 = time.time()

    print("----------> warmup compilation time:", t2 - t1, "seconds")
    print("----------> 1st compilation time:", t4 - t3, "seconds")
    print("----------> 2nd compilation time:", t6 - t5, "seconds")


if __name__ == "__main__":
    dynamo_path()
    compile_path()
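
One caveat when reading these timings: the forward passes run asynchronously on the GPU, so if the script were extended to time inference (rather than compilation), each measurement would need a device synchronization. A minimal sketch of such a helper (the `timed` name and structure are illustrative, not part of this commit):

import time

import torch


def timed(fn):
    # Synchronize before and after so any pending CUDA work is included
    # in the wall-clock measurement.
    torch.cuda.synchronize()
    start = time.time()
    result = fn()
    torch.cuda.synchronize()
    return result, time.time() - start


# e.g. outputs, seconds = timed(lambda: trt_gm1(*inputs))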
@@ -0,0 +1,152 @@
import ast
import copy
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, cast

import torch
from torch._inductor.codecache import FxGraphCachePickler
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR, ENGINE_CACHE_SIZE

_LOGGER: logging.Logger = logging.getLogger(__name__)


class BaseEngineCache(ABC):

    @staticmethod
    def get_hash(gm: torch.fx.GraphModule) -> str:
        """Get the hash value of the GraphModule

        Args:
            gm (torch.fx.GraphModule): GraphModule to hash

        Returns:
            str: hash value of the GraphModule
        """
        # Zero out the parameters on a copy so the hash depends only on the
        # graph structure, not on the weight values.
        with maybe_disable_fake_tensor_mode():
            new_gm = copy.deepcopy(gm)
            for name, param in new_gm.named_parameters():
                param.data.zero_()

        # Hash the zeroed copy, not the original module.
        hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))

        return hash_val

    @abstractmethod
    def save(
        self,
        hash: str,
        serialized_engine: bytes,
        input_names: List[str],
        output_names: List[str],
    ) -> None:
        """Save the serialized engine to hard disk

        Args:
            hash (str): hash value of the GraphModule
            serialized_engine (bytes): serialized TRT engine
            input_names (List[str]): input names of TRT engine
            output_names (List[str]): output names of TRT engine

        Returns:
            None
        """
        pass

    @abstractmethod
    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
        """Load the serialized engine from hard disk

        Args:
            hash (str): hash value of the GraphModule

        Returns:
            Tuple[Optional[bytes], List[str], List[str]]: serialized TRT engine, input names of TRT engine, output names of TRT engine
        """
        pass

    @abstractmethod
    def clear_cache(self, size: int) -> None:
        """Clear the cache to make sure at least `size` bytes are available

        Args:
            size (int): the needed size
        """
        pass


class EngineCache(BaseEngineCache):

    def __init__(
        self,
        engine_cache_size: int = ENGINE_CACHE_SIZE,
        engine_cache_dir: str = ENGINE_CACHE_DIR,
    ) -> None:
        self.total_engine_cache_size = engine_cache_size
        self.available_engine_cache_size = engine_cache_size
        self.engine_cache_dir = engine_cache_dir

    def has_available_cache_size(self, serialized_engine: bytes) -> bool:
        """Check if the cache has available space for saving the serialized engine

        Args:
            serialized_engine (bytes): serialized TRT engine

        Returns:
            bool: whether the cache has available size for the serialized engine
        """
        return len(serialized_engine) <= self.available_engine_cache_size

    def clear_cache(self, size: int) -> None:
        # TODO: evict least-recently-used engines until `size` bytes are free

        def LRU() -> None:
            pass

        pass

    def save(
        self,
        hash: str,
        serialized_engine: bytes,
        input_names: List[str],
        output_names: List[str],
    ) -> None:
        serialized_engine_size = len(serialized_engine)
        # Refuse to cache an engine that is larger than the entire cache.
        if serialized_engine_size > self.total_engine_cache_size:
            _LOGGER.warning(
                f"The serialized engine cannot be saved because the size of the engine {serialized_engine_size} is larger than the total cache size {self.total_engine_cache_size}."
            )
            return

        if not self.has_available_cache_size(serialized_engine):
            self.clear_cache(serialized_engine_size)

        # The input/output names are embedded in the file name and recovered in
        # load(); this assumes the names themselves contain no underscores or dots.
        path = os.path.join(
            self.engine_cache_dir, f"{hash}/engine_{input_names}_{output_names}.trt"
        )
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            f.write(serialized_engine)
        _LOGGER.info(f"A TRT engine was cached to {path}")

    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
        directory = os.path.join(self.engine_cache_dir, hash)
        if os.path.exists(directory):
            engine_list = os.listdir(directory)
            assert (
                len(engine_list) == 1
            ), f"There should be exactly one engine per hash, but found {engine_list} under {directory}."
            path = os.path.join(directory, engine_list[0])
            # File names look like "engine_['in0']_['out0'].trt"; parse the
            # bracketed name lists back out of the file name.
            input_names_str, output_names_str = (
                engine_list[0].split(".")[0].split("_")[1:]
            )
            input_names = ast.literal_eval(input_names_str)
            output_names = ast.literal_eval(output_names_str)
            with open(path, "rb") as f:
                serialized_engine = f.read()
            return serialized_engine, input_names, output_names
        else:
            return None, [], []
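
For reference, a save/load round trip with the `EngineCache` above might look like the following sketch. The `engine_bytes` payload is a stand-in for a real serialized TRT engine, and the cache size and directory are arbitrary assumptions:

import torch
import torchvision.models as models

cache = EngineCache(engine_cache_size=1 << 30, engine_cache_dir="/tmp/trt_engine_cache")

# Any GraphModule can serve as the key source; parameters are zeroed before
# hashing, so modules with the same graph but different weights share a key.
gm = torch.fx.symbolic_trace(models.resnet18().eval())
key = EngineCache.get_hash(gm)

engine_bytes = b"fake serialized engine"  # stand-in for real engine bytes
cache.save(key, engine_bytes, ["input0"], ["output0"])

blob, input_names, output_names = cache.load(key)
assert blob == engine_bytes and input_names == ["input0"]

Note that the names used here contain no underscores or dots, matching the file-name encoding that save() and load() rely on.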