-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
64 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import numpy as np | ||
import torch | ||
import torch_tensorrt | ||
from engine_caching_example import remove_timing_cache | ||
from transformers import BertModel | ||
|
||
np.random.seed(0) | ||
torch.manual_seed(0) | ||
|
||
model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval() | ||
inputs = [ | ||
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), | ||
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), | ||
] | ||
|
||
|
||
def compile_bert(iterations=3): | ||
times = [] | ||
start = torch.cuda.Event(enable_timing=True) | ||
end = torch.cuda.Event(enable_timing=True) | ||
|
||
# The 1st iteration is to measure the compilation time without engine caching | ||
# The 2nd and 3rd iterations are to measure the compilation time with engine caching. | ||
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. | ||
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. | ||
for i in range(iterations): | ||
# remove timing cache and reset dynamo for engine caching messurement | ||
remove_timing_cache() | ||
torch._dynamo.reset() | ||
|
||
if i == 0: | ||
save_engine_cache = False | ||
load_engine_cache = False | ||
else: | ||
save_engine_cache = True | ||
load_engine_cache = True | ||
|
||
start.record() | ||
compilation_kwargs = { | ||
"use_python_runtime": False, | ||
"enabled_precisions": {torch.float}, | ||
"truncate_double": True, | ||
"debug": True, | ||
"min_block_size": 1, | ||
"make_refitable": True, | ||
"save_engine_cache": save_engine_cache, | ||
"load_engine_cache": load_engine_cache, | ||
"engine_cache_size": 1 << 30, # 1GB | ||
} | ||
optimized_model = torch.compile( | ||
model, | ||
backend="torch_tensorrt", | ||
options=compilation_kwargs, | ||
) | ||
optimized_model(*inputs) | ||
end.record() | ||
torch.cuda.synchronize() | ||
times.append(start.elapsed_time(end)) | ||
|
||
print("-----compile bert-----> compilation time:", times, "milliseconds") | ||
|
||
|
||
if __name__ == "__main__": | ||
compile_bert() |