Skip to content

Commit

Permalink
add bert example
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Aug 12, 2024
1 parent 315c95c commit cb8d30b
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions examples/dynamo/engine_caching_bert_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np
import torch
import torch_tensorrt
from engine_caching_example import remove_timing_cache
from transformers import BertModel

# Seed both NumPy and PyTorch so the random example inputs are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Pretrained BERT, moved to the GPU and switched to eval mode.
# return_dict=False is forwarded to from_pretrained as in the upstream example.
model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()

# Two int32 tensors of shape (1, 14) with values drawn from {0, 1}, created on
# the CPU and then copied to the GPU. They are passed positionally to the model
# (presumably as input_ids and attention_mask -- confirm against BertModel).
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") for _ in range(2)
]


def compile_bert(iterations=3):
    """Time repeated ``torch.compile`` runs of BERT on the torch_tensorrt backend.

    Pass 0 runs with ``save_engine_cache`` and ``load_engine_cache`` both
    disabled; every later pass enables both. The expectation stated by the
    example's author is that the second pass is slower than the first
    (it compiles and also saves the engine) and the third pass is faster
    than the first (it loads the cached engine).

    Args:
        iterations: Number of compile-and-run passes to time.

    Returns:
        None. The per-pass elapsed times (milliseconds, as reported by the
        CUDA events) are printed as a list.
    """
    elapsed_ms = []
    start_evt = torch.cuda.Event(enable_timing=True)
    stop_evt = torch.cuda.Event(enable_timing=True)

    for pass_idx in range(iterations):
        # Clear the timing cache and reset dynamo before every pass so each
        # measurement starts from the same state.
        remove_timing_cache()
        torch._dynamo.reset()

        # Engine caching is off only for the very first pass.
        use_engine_cache = pass_idx != 0

        start_evt.record()
        backend_options = {
            "use_python_runtime": False,
            "enabled_precisions": {torch.float},
            "truncate_double": True,
            "debug": True,
            "min_block_size": 1,
            "make_refitable": True,
            "save_engine_cache": use_engine_cache,
            "load_engine_cache": use_engine_cache,
            "engine_cache_size": 1 << 30,  # 1GB
        }
        trt_bert = torch.compile(
            model, backend="torch_tensorrt", options=backend_options
        )
        # The forward call sits inside the timed region, so compilation work
        # deferred to the first invocation is included in the measurement.
        trt_bert(*inputs)
        stop_evt.record()

        # Wait for the GPU to finish so both events have completed before
        # the elapsed time is read.
        torch.cuda.synchronize()
        elapsed_ms.append(start_evt.elapsed_time(stop_evt))

    print("-----compile bert-----> compilation time:", elapsed_ms, "milliseconds")


# Script entry point: run the benchmark only when executed directly, so that
# importing this module does not kick off a compilation.
if __name__ == "__main__":
    compile_bert()

0 comments on commit cb8d30b

Please sign in to comment.