feat: Add functionality for easily benchmarking fx
- Add fx path in benchmarking code
- Add fx saving tools to `utils` and `hub`
- Add PyTorch model parsing and loading in `perf_run` script
gs-olive committed Dec 2, 2022
1 parent 2b1cedf commit 2b130ec
Showing 4 changed files with 104 additions and 22 deletions.
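At a high level, the change keeps the existing TorchScript flow for the torch, torch_tensorrt, and tensorrt backends and adds a parallel full-module flow for the new fx2trt backend. A minimal sketch of the two loading paths the updated perf_run.py uses, assuming the files produced by hub.py exist under models/ and a CUDA device is available:

import torch

# Existing backends consume a TorchScript file (scripted or traced).
model = torch.jit.load("models/vgg16_scripted.jit.pt").cuda().eval()

# The fx2trt backend needs the original nn.Module, saved and re-loaded whole.
model_torch = torch.load("models/vgg16_pytorch.pt").eval().cuda()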
9 changes: 6 additions & 3 deletions tools/perf/benchmark.sh
@@ -12,9 +12,10 @@ echo "Benchmarking VGG16 model"
for bs in ${batch_sizes[@]}
do
python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
--model_torch ${MODELS_DIR}/vgg16_pytorch.pt \
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
--batch_size ${bs} \
--backends torch,torch_tensorrt,tensorrt \
--backends torch,torch_tensorrt,tensorrt,fx2trt \
--report "vgg_perf_bs${bs}.txt"
done

@@ -23,9 +24,10 @@ echo "Benchmarking Resnet50 model"
for bs in ${batch_sizes[@]}
do
python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \
--model_torch ${MODELS_DIR}/resnet50_pytorch.pt \
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
--batch_size ${bs} \
--backends torch,torch_tensorrt,tensorrt \
--backends torch,torch_tensorrt,tensorrt,fx2trt \
--report "rn50_perf_bs${bs}.txt"
done

@@ -45,9 +47,10 @@ echo "Benchmarking EfficientNet-B0 model"
for bs in ${batch_sizes[@]}
do
python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \
--model_torch ${MODELS_DIR}/efficientnet_b0_pytorch.pt \
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
--batch_size ${bs} \
--backends torch,torch_tensorrt,tensorrt \
--backends torch,torch_tensorrt,tensorrt,fx2trt \
--report "eff_b0_perf_bs${bs}.txt"
done
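The new --model_torch arguments above point at full-module saves that hub.py (next file) now produces alongside the TorchScript artifacts. A minimal sketch of that save step, assuming torchvision is installed, the models/ directory exists, and a CUDA device is available, as in hub.py:

import torch
import torchvision.models as models

# Mirror of the new "pytorch" save path in hub.py: serialize the eval'd module whole,
# so perf_run.py can later recover it with torch.load for the fx2trt backend.
model = models.resnet50(weights=None).eval().cuda()
torch.save(model, "models/resnet50_pytorch.pt")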

75 changes: 63 additions & 12 deletions tools/perf/hub.py
@@ -21,12 +21,19 @@
# Downloads all model files again if manifest file is not present
MANIFEST_FILE = "model_manifest.json"

# Valid paths for model-saving specification
VALID_PATHS = ("script", "trace", "torchscript", "pytorch", "all")

# Key models selected for benchmarking with their respective paths
BENCHMARK_MODELS = {
"vgg16": {"model": models.vgg16(weights=None), "path": "script"},
"resnet50": {"model": models.resnet50(weights=None), "path": "script"},
"vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
"resnet50": {
"model": models.resnet50(weights=None),
"path": ["script", "pytorch"],
},
"efficientnet_b0": {
"model": timm.create_model("efficientnet_b0", pretrained=True),
"path": "script",
"path": ["script", "pytorch"],
},
"vit": {
"model": timm.create_model("vit_base_patch16_224", pretrained=True),
@@ -40,18 +47,26 @@ def get(n, m, manifest):
print("Downloading {}".format(n))
traced_filename = "models/" + n + "_traced.jit.pt"
script_filename = "models/" + n + "_scripted.jit.pt"
pytorch_filename = "models/" + n + "_pytorch.pt"
x = torch.ones((1, 3, 300, 300)).cuda()
if n == "bert-base-uncased":
if n == "bert_base_uncased":
traced_model = m["model"]
torch.jit.save(traced_model, traced_filename)
manifest.update({n: [traced_filename]})
else:
m["model"] = m["model"].eval().cuda()
if m["path"] == "both" or m["path"] == "trace":

# Get all desired model save specifications as list
paths = [m["path"]] if isinstance(m["path"], str) else m["path"]

# Depending on specified model save specifications, save desired model formats
if any(path in ("all", "torchscript", "trace") for path in paths):
# (TorchScript) Traced model
trace_model = torch.jit.trace(m["model"], [x])
torch.jit.save(trace_model, traced_filename)
manifest.update({n: [traced_filename]})
if m["path"] == "both" or m["path"] == "script":
if any(path in ("all", "torchscript", "script") for path in paths):
# (TorchScript) Scripted model
script_model = torch.jit.script(m["model"])
torch.jit.save(script_model, script_filename)
if n in manifest.keys():
@@ -60,6 +75,15 @@ def get(n, m, manifest):
manifest.update({n: files})
else:
manifest.update({n: [script_filename]})
if any(path in ("all", "pytorch") for path in paths):
# (PyTorch Module) model
torch.save(m["model"], pytorch_filename)
if n in manifest.keys():
files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
files.append(pytorch_filename)
manifest.update({n: files})
else:
manifest.update({n: [pytorch_filename]})
return manifest
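The save-format handling above, and the existence checks in download_models below, repeatedly normalize the "path" specification, which may be a single string or a list of strings, before testing membership. A self-contained sketch of that idiom using a hypothetical helper:

def wants(spec, *formats):
    # Normalize a "path" spec (str or list of str) and check for any requested format.
    paths = [spec] if isinstance(spec, str) else spec
    return any(path in formats for path in paths)

assert wants("script", "all", "torchscript", "script")
assert wants(["script", "pytorch"], "all", "pytorch")
assert not wants(["trace"], "pytorch")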


@@ -72,15 +96,35 @@ def download_models(version_matches, manifest):
for n, m in BENCHMARK_MODELS.items():
scripted_filename = "models/" + n + "_scripted.jit.pt"
traced_filename = "models/" + n + "_traced.jit.pt"
pytorch_filename = "models/" + n + "_pytorch.pt"
# Check if model file exists on disk

# Extract model specifications as list and ensure all desired formats exist
paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
if (
(
m["path"] == "both"
any(path == "all" for path in paths)
and os.path.exists(scripted_filename)
and os.path.exists(traced_filename)
and os.path.exists(pytorch_filename)
)
or (
any(path == "torchscript" for path in paths)
and os.path.exists(scripted_filename)
and os.path.exists(traced_filename)
)
or (m["path"] == "script" and os.path.exists(scripted_filename))
or (m["path"] == "trace" and os.path.exists(traced_filename))
or (
any(path == "script" for path in paths)
and os.path.exists(scripted_filename)
)
or (
any(path == "trace" for path in paths)
and os.path.exists(traced_filename)
)
or (
any(path == "pytorch" for path in paths)
and os.path.exists(pytorch_filename)
)
):
print("Skipping {} ".format(n))
continue
@@ -90,7 +134,6 @@ def download_models(version_matches, manifest):
def main():
manifest = None
version_matches = False
manifest_exists = False

# Check if Manifest file exists or is empty
if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0:
@@ -99,7 +142,6 @@ def main():
# Creating an empty manifest file for overwriting post setup
os.system("touch {}".format(MANIFEST_FILE))
else:
manifest_exists = True

# Load manifest if already exists
with open(MANIFEST_FILE, "r") as f:
@@ -129,4 +171,13 @@ def main():
f.truncate()


main()
if __name__ == "__main__":
# Ensure all model save specifications in BENCHMARK_MODELS are valid
paths = [
[m["path"]] if isinstance(m["path"], str) else m["path"]
for m in BENCHMARK_MODELS.values()
]
assert all(
(path in VALID_PATHS) for path_list in paths for path in path_list
), "Not all 'path' attributes in BENCHMARK_MODELS are valid"
main()
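For orientation, the manifest that get() maintains and main() writes to model_manifest.json maps each model name to the list of files saved for it. A hypothetical entry after handling vgg16 with path ["script", "pytorch"] (illustrative only; actual contents depend on the requested formats and what already exists on disk):

# Hypothetical model_manifest.json contents, shown as a Python literal.
example_manifest = {
    "vgg16": ["models/vgg16_scripted.jit.pt", "models/vgg16_pytorch.pt"],
}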
34 changes: 31 additions & 3 deletions tools/perf/perf_run.py
@@ -4,6 +4,7 @@

import time
import timeit
import warnings
import numpy as np
import torch.backends.cudnn as cudnn

@@ -147,6 +148,7 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
max_batch_size=batch_size,
lower_precision=precision,
verbose_log=False,
explicit_batch_dimension=True,
)
end_compile = time.time_ns()
compile_time_ms = (end_compile - start_compile) / 1e6
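For reference, the compile step above wraps the fx2trt lowering call with nanosecond timing converted to milliseconds. A minimal sketch of that pattern, with compile_fn standing in for the actual fx2trt compile entry point (which is outside this hunk):

import time

def timed_compile(compile_fn, model, inputs, precision, batch_size):
    # compile_fn is a stand-in for the fx2trt compile call used in run_fx2trt.
    start_compile = time.time_ns()
    trt_model = compile_fn(
        model,
        inputs,
        max_batch_size=batch_size,
        lower_precision=precision,
        verbose_log=False,
        explicit_batch_dimension=True,  # newly enabled in this change
    )
    end_compile = time.time_ns()
    compile_time_ms = (end_compile - start_compile) / 1e6  # ns -> ms
    return trt_model, compile_time_ms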
@@ -272,6 +274,7 @@ def run(
truncate_long_and_double=False,
batch_size=1,
is_trt_engine=False,
model_torch=None,
):
for backend in backends:
if precision == "int8":
@@ -323,7 +326,13 @@
)

elif backend == "fx2trt":
run_fx2trt(model, input_tensors, params, precision, batch_size)
if model_torch is None:
warnings.warn(
"Requested backend fx2trt without specifying a PyTorch Model, "
+ "skipping this backend"
)
continue
run_fx2trt(model_torch, input_tensors, params, precision, batch_size)

elif backend == "tensorrt":
run_tensorrt(
@@ -399,7 +408,13 @@ def load_model(params):
type=str,
help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt",
)
arg_parser.add_argument("--model", type=str, help="Name of the model file")
arg_parser.add_argument("--model", type=str, help="Name of torchscript model file")
arg_parser.add_argument(
"--model_torch",
type=str,
default="",
help="Name of torch model file (used for fx2trt)",
)
arg_parser.add_argument(
"--inputs",
type=str,
@@ -491,16 +506,28 @@ def load_model(params):
else:
params = vars(args)
model_name = params["model"]
model = None

model_name_torch = params["model_torch"]
model_torch = None

# Load TorchScript model
if os.path.exists(model_name):
print("Loading user provided model: ", model_name)
print("Loading user provided torchscript model: ", model_name)
model = torch.jit.load(model_name).cuda().eval()
elif model_name in BENCHMARK_MODELS:
print("Loading torchscript model from BENCHMARK_MODELS for: ", model_name)
model = BENCHMARK_MODELS[model_name]["model"].eval().cuda()
else:
raise ValueError(
"Invalid model name. Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)"
)

# Load PyTorch Model, if provided
if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
print("Loading user provided torch model: ", model_name_torch)
model_torch = torch.load(model_name_torch).eval().cuda()

backends = parse_backends(params["backends"])
truncate_long_and_double = params["truncate"]
batch_size = params["batch_size"]
@@ -523,6 +550,7 @@
truncate_long_and_double,
batch_size,
is_trt_engine,
model_torch=model_torch,
)

# Generate report
8 changes: 4 additions & 4 deletions tools/perf/utils.py
@@ -5,14 +5,14 @@
import timm

BENCHMARK_MODELS = {
"vgg16": {"model": models.vgg16(pretrained=True), "path": "script"},
"vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
"resnet50": {
"model": torch.hub.load("pytorch/vision:v0.9.0", "resnet50", pretrained=True),
"path": "script",
"model": models.resnet50(weights=None),
"path": ["script", "pytorch"],
},
"efficientnet_b0": {
"model": timm.create_model("efficientnet_b0", pretrained=True),
"path": "script",
"path": ["script", "pytorch"],
},
"vit": {
"model": timm.create_model("vit_base_patch16_224", pretrained=True),
