-
Notifications
You must be signed in to change notification settings - Fork 354
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1254 from pytorch/perf_changes
feat(//tools/perf): Refactor perf_run.py, add fx2trt backend support, usage via CLI arguments
- Loading branch information
Showing
7 changed files
with
633 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#!/bin/bash | ||
|
||
MODELS_DIR="models" | ||
|
||
# Download the Torchscript models | ||
python hub.py | ||
|
||
batch_sizes=(1 2 4 8 16 32 64 128 256) | ||
|
||
#Benchmark VGG16 model | ||
echo "Benchmarking VGG16 model" | ||
for bs in ${batch_sizes[@]} | ||
do | ||
python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \ | ||
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ | ||
--batch_size ${bs} \ | ||
--backends torch,torch_tensorrt,tensorrt \ | ||
--report "vgg_perf_bs${bs}.txt" | ||
done | ||
|
||
# Benchmark Resnet50 model | ||
echo "Benchmarking Resnet50 model" | ||
for bs in ${batch_sizes[@]} | ||
do | ||
python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \ | ||
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ | ||
--batch_size ${bs} \ | ||
--backends torch,torch_tensorrt,tensorrt \ | ||
--report "rn50_perf_bs${bs}.txt" | ||
done | ||
|
||
# Benchmark VIT model | ||
echo "Benchmarking VIT model" | ||
for bs in ${batch_sizes[@]} | ||
do | ||
python perf_run.py --model ${MODELS_DIR}/vit_scripted.jit.pt \ | ||
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ | ||
--batch_size ${bs} \ | ||
--backends torch,torch_tensorrt,tensorrt \ | ||
--report "vit_perf_bs${bs}.txt" | ||
done | ||
|
||
# Benchmark EfficientNet-B0 model | ||
echo "Benchmarking EfficientNet-B0 model" | ||
for bs in ${batch_sizes[@]} | ||
do | ||
python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \ | ||
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ | ||
--batch_size ${bs} \ | ||
--backends torch,torch_tensorrt,tensorrt \ | ||
--report "eff_b0_perf_bs${bs}.txt" | ||
done | ||
|
||
# Benchmark BERT model | ||
echo "Benchmarking Huggingface BERT base model" | ||
for bs in ${batch_sizes[@]} | ||
do | ||
python perf_run.py --model ${MODELS_DIR}/bert_base_uncased_traced.jit.pt \ | ||
--precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \ | ||
--batch_size ${bs} \ | ||
--backends torch,torch_tensorrt \ | ||
--truncate \ | ||
--report "bert_base_perf_bs${bs}.txt" | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch | ||
import torch.nn as nn | ||
from transformers import BertModel, BertTokenizer, BertConfig | ||
import torch.nn.functional as F | ||
|
||
|
||
def BertModule(): | ||
model_name = "bert-base-uncased" | ||
enc = BertTokenizer.from_pretrained(model_name) | ||
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" | ||
tokenized_text = enc.tokenize(text) | ||
masked_index = 8 | ||
tokenized_text[masked_index] = "[MASK]" | ||
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text) | ||
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] | ||
tokens_tensor = torch.tensor([indexed_tokens]) | ||
segments_tensors = torch.tensor([segments_ids]) | ||
config = BertConfig( | ||
vocab_size_or_config_json_file=32000, | ||
hidden_size=768, | ||
num_hidden_layers=12, | ||
num_attention_heads=12, | ||
intermediate_size=3072, | ||
torchscript=True, | ||
) | ||
model = BertModel(config) | ||
model.eval() | ||
model = BertModel.from_pretrained(model_name, torchscript=True) | ||
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) | ||
return traced_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torchvision.models as models | ||
import timm | ||
from transformers import BertModel, BertTokenizer, BertConfig | ||
import os | ||
import json | ||
import custom_models as cm | ||
|
||
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True | ||
|
||
torch_version = torch.__version__ | ||
|
||
# Detect case of no GPU before deserialization of models on GPU | ||
if not torch.cuda.is_available(): | ||
raise Exception( | ||
"No GPU found. Please check if installed torch version is compatible with CUDA version" | ||
) | ||
|
||
# Downloads all model files again if manifest file is not present | ||
MANIFEST_FILE = "model_manifest.json" | ||
|
||
BENCHMARK_MODELS = { | ||
"vgg16": {"model": models.vgg16(weights=None), "path": "script"}, | ||
"resnet50": {"model": models.resnet50(weights=None), "path": "script"}, | ||
"efficientnet_b0": { | ||
"model": timm.create_model("efficientnet_b0", pretrained=True), | ||
"path": "script", | ||
}, | ||
"vit": { | ||
"model": timm.create_model("vit_base_patch16_224", pretrained=True), | ||
"path": "script", | ||
}, | ||
"bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, | ||
} | ||
|
||
|
||
def get(n, m, manifest): | ||
print("Downloading {}".format(n)) | ||
traced_filename = "models/" + n + "_traced.jit.pt" | ||
script_filename = "models/" + n + "_scripted.jit.pt" | ||
x = torch.ones((1, 3, 300, 300)).cuda() | ||
if n == "bert-base-uncased": | ||
traced_model = m["model"] | ||
torch.jit.save(traced_model, traced_filename) | ||
manifest.update({n: [traced_filename]}) | ||
else: | ||
m["model"] = m["model"].eval().cuda() | ||
if m["path"] == "both" or m["path"] == "trace": | ||
trace_model = torch.jit.trace(m["model"], [x]) | ||
torch.jit.save(trace_model, traced_filename) | ||
manifest.update({n: [traced_filename]}) | ||
if m["path"] == "both" or m["path"] == "script": | ||
script_model = torch.jit.script(m["model"]) | ||
torch.jit.save(script_model, script_filename) | ||
if n in manifest.keys(): | ||
files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] | ||
files.append(script_filename) | ||
manifest.update({n: files}) | ||
else: | ||
manifest.update({n: [script_filename]}) | ||
return manifest | ||
|
||
|
||
def download_models(version_matches, manifest): | ||
# Download all models if torch version is different than model version | ||
if not version_matches: | ||
for n, m in BENCHMARK_MODELS.items(): | ||
manifest = get(n, m, manifest) | ||
else: | ||
for n, m in BENCHMARK_MODELS.items(): | ||
scripted_filename = "models/" + n + "_scripted.jit.pt" | ||
traced_filename = "models/" + n + "_traced.jit.pt" | ||
# Check if model file exists on disk | ||
if ( | ||
( | ||
m["path"] == "both" | ||
and os.path.exists(scripted_filename) | ||
and os.path.exists(traced_filename) | ||
) | ||
or (m["path"] == "script" and os.path.exists(scripted_filename)) | ||
or (m["path"] == "trace" and os.path.exists(traced_filename)) | ||
): | ||
print("Skipping {} ".format(n)) | ||
continue | ||
manifest = get(n, m, manifest) | ||
|
||
|
||
def main(): | ||
manifest = None | ||
version_matches = False | ||
manifest_exists = False | ||
|
||
# Check if Manifest file exists or is empty | ||
if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0: | ||
manifest = {"version": torch_version} | ||
|
||
# Creating an empty manifest file for overwriting post setup | ||
os.system("touch {}".format(MANIFEST_FILE)) | ||
else: | ||
manifest_exists = True | ||
|
||
# Load manifest if already exists | ||
with open(MANIFEST_FILE, "r") as f: | ||
manifest = json.load(f) | ||
if manifest["version"] == torch_version: | ||
version_matches = True | ||
else: | ||
print( | ||
"Torch version: {} mismatches \ | ||
with manifest's version: {}. Re-downloading \ | ||
all models".format( | ||
torch_version, manifest["version"] | ||
) | ||
) | ||
|
||
# Overwrite the manifest version as current torch version | ||
manifest["version"] = torch_version | ||
|
||
download_models(version_matches, manifest) | ||
|
||
# Write updated manifest file to disk | ||
with open(MANIFEST_FILE, "r+") as f: | ||
data = f.read() | ||
f.seek(0) | ||
record = json.dumps(manifest) | ||
f.write(record) | ||
f.truncate() | ||
|
||
|
||
main() |
Oops, something went wrong.