diff --git a/tools/perf/README.md b/tools/perf/README.md
index 45630b4f29..452621e939 100644
--- a/tools/perf/README.md
+++ b/tools/perf/README.md
@@ -66,9 +66,10 @@ There are two sample configuration files added.

 | Name | Supported Values | Description |
 | ----------------- | ------------------------------------ | ------------------------------------------------------------ |
-| backend | all, torch, torch_tensorrt, tensorrt | Supported backends for inference. |
+| backend | all, torchscript, fx2trt, torch, torch_tensorrt, tensorrt | Supported backends for inference. "all" runs torch, torch_tensorrt, tensorrt, and fx2trt; "torchscript" runs torch, torch_tensorrt, and tensorrt (excludes the fx path) |
 | input | - | Input binding names. Expected to list shapes of each input bindings |
 | model | - | Configure the model filename and name |
+| model_torch | - | Configure the PyTorch model filename and name (required for fx2trt, optional otherwise) |
 | filename | - | Model file name to load from disk. |
 | name | - | Model name |
 | runtime | - | Runtime configurations |
@@ -83,6 +84,7 @@ backend:
   - torch
   - torch_tensorrt
   - tensorrt
+  - fx2trt
 input:
   input0:
     - 3
@@ -92,6 +94,9 @@ input:
 model:
   filename: model.plan
   name: vgg16
+model_torch:
+  filename: model_torch.pt
+  name: vgg16
 runtime:
   device: 0
   precision:
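The `model` and `model_torch` sections are independent, and either may be omitted; backends whose model type is missing are skipped rather than failing (see the `perf_run.py` changes below). As a minimal sketch of the expected structure, assuming plain PyYAML (the tool itself reads configs through its own `ConfigParser`, and the filename here is illustrative):

```python
import yaml  # illustrative; perf_run.py wraps config reading in its own ConfigParser

with open("config.yml") as f:  # filename is illustrative
    params = yaml.safe_load(f)

# Either section may be absent; a missing section means the
# corresponding backends are skipped instead of raising an error.
ts_cfg = params.get("model", {})           # TorchScript module or TRT engine
torch_cfg = params.get("model_torch", {})  # eager PyTorch module (for fx2trt)
print(ts_cfg.get("filename"), torch_cfg.get("filename"))
```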
@@ -108,8 +113,9 @@ Note:

 Here are the list of `CompileSpec` options that can be provided directly to compile the pytorch module

-* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt, tensorrt or fx2trt
+* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt,fx2trt
 * `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `fx2trt`, the input should be a Pytorch module (instead of a torchscript module) and the options for model are (`vgg16` | `resnet50` | `efficientnet_b0`)
+* `--model_torch` : Name of the PyTorch model file (optional; required only when fx2trt is among the chosen backends)
 * `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
 * `--batch_size` : Batch size
 * `--precision` : Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16
@@ -122,9 +128,10 @@ Eg:

 ```
 python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
+    --model_torch ${MODELS_DIR}/vgg16_torch.pt \
     --precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
     --batch_size 1 \
-    --backends torch,torch_tensorrt,tensorrt \
+    --backends torch,torch_tensorrt,tensorrt,fx2trt \
     --report "vgg_perf_bs1.txt"
 ```
diff --git a/tools/perf/hub.py b/tools/perf/hub.py
index 3fa57eb862..a1f032212b 100644
--- a/tools/perf/hub.py
+++ b/tools/perf/hub.py
@@ -26,7 +26,10 @@

 # Key models selected for benchmarking with their respective paths
 BENCHMARK_MODELS = {
-    "vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
+    "vgg16": {
+        "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
+        "path": ["script", "pytorch"],
+    },
     "resnet50": {
         "model": models.resnet50(weights=None),
         "path": ["script", "pytorch"],
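For context on the `hub.py` change just above: torchvision 0.13 deprecated the boolean `pretrained` flag in favor of explicit weight enums, which is the spelling this diff adopts. A minimal sketch of the two idioms:

```python
import torchvision.models as models

# Deprecated spelling; emits a deprecation warning on torchvision >= 0.13:
#   vgg16 = models.vgg16(pretrained=True)

# Current spelling: request pretrained weights via the weights enum.
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# Passing weights=None yields a randomly initialized network,
# as the resnet50 entry in BENCHMARK_MODELS does.
resnet50 = models.resnet50(weights=None)
```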
diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py
index 610de72603..78a3be5dd5 100644
--- a/tools/perf/perf_run.py
+++ b/tools/perf/perf_run.py
@@ -292,6 +292,20 @@ def run(
             print("int8 precision expects calibration cache file for inference")
             return False

+        if (model is None) and (backend != "fx2trt"):
+            warnings.warn(
+                f"Requested backend {backend} without specifying a TorchScript Model, "
+                + "skipping this backend"
+            )
+            continue
+
+        if (model_torch is None) and (backend in ("all", "fx2trt")):
+            warnings.warn(
+                f"Requested backend {backend} without specifying a PyTorch Model, "
+                + "skipping this backend"
+            )
+            continue
+
         if backend == "all":
             run_torch(model, input_tensors, params, precision, batch_size)
             run_torch_tensorrt(
                 model,
                 input_tensors,
                 params,
                 precision,
                 truncate_long_and_double,
                 batch_size,
             )
@@ -311,6 +325,27 @@
                 is_trt_engine,
                 batch_size,
             )
+            run_fx2trt(model_torch, input_tensors, params, precision, batch_size)
+
+        elif backend == "torchscript":
+            run_torch(model, input_tensors, params, precision, batch_size)
+            run_torch_tensorrt(
+                model,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                batch_size,
+            )
+            run_tensorrt(
+                model,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                is_trt_engine,
+                batch_size,
+            )

         elif backend == "torch":
             run_torch(model, input_tensors, params, precision, batch_size)
@@ -326,12 +361,6 @@
             )

         elif backend == "fx2trt":
-            if model_torch is None:
-                warnings.warn(
-                    "Requested backend fx2trt without specifying a PyTorch Model, "
-                    + "skipping this backend"
-                )
-                continue
             run_fx2trt(model_torch, input_tensors, params, precision, batch_size)

         elif backend == "tensorrt":
@@ -371,9 +400,14 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None)
     results.append(stats)


-def load_model(params):
+def load_ts_model(params):
     model = None
     is_trt_engine = False
+
+    # No TorchScript Model Specified
+    if len(params.get("model", "")) == 0:
+        return None, None, is_trt_engine
+
     # Load torch model traced/scripted
     model_file = params.get("model").get("filename")
     try:
@@ -393,6 +427,26 @@
     return model, model_name, is_trt_engine


+def load_torch_model(params):
+    model = None
+
+    # No Torch Model Specified
+    if len(params.get("model_torch", "")) == 0:
+        return None, None
+
+    # Load torch model
+    model_file = params.get("model_torch").get("filename")
+    try:
+        model_name = params.get("model_torch").get("name")
+    except:
+        model_name = model_file
+
+    print("Loading Torch model: ", model_file)
+    model = torch.load(model_file).cuda()
+
+    return model, model_name
+
+
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser(
         description="Run inference on a model with random input values"
     )
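Note that `load_torch_model` deserializes the file with `torch.load`, so the `model_torch` entry must point at a fully pickled `nn.Module` rather than a bare state_dict. A minimal sketch of producing a compatible file (the model choice is illustrative; the filename matches the sample config above):

```python
import torch
import torchvision.models as models

# Save the entire module, not just its state_dict, so that
# torch.load() in load_torch_model() returns a callable model.
model = models.vgg16(weights=models.VGG16_Weights.DEFAULT).eval()
torch.save(model, "model_torch.pt")  # filename matches the sample config
```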
@@ -408,7 +462,9 @@ def load_model(params):
         type=str,
         help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt",
     )
-    arg_parser.add_argument("--model", type=str, help="Name of torchscript model file")
+    arg_parser.add_argument(
+        "--model", type=str, default="", help="Name of torchscript model file"
+    )
     arg_parser.add_argument(
         "--model_torch",
         type=str,
@@ -458,7 +514,16 @@ def load_model(params):
         parser = ConfigParser(args.config)
         # Load YAML params
         params = parser.read_config()
-        model, model_name, is_trt_engine = load_model(params)
+        model, model_name, is_trt_engine = load_ts_model(params)
+        model_torch, model_name_torch = load_torch_model(params)
+
+        # If neither model type was provided
+        if (model is None) and (model_torch is None):
+            raise ValueError(
+                "No valid models specified. Please provide a torchscript model file or model name "
+                + "(among the following options vgg16|resnet50|efficientnet_b0|vit) "
+                + "or provide a torch model file"
+            )

         # Default device is set to 0. Configurable using yaml config file.
         torch.cuda.set_device(params.get("runtime").get("device", 0))
@@ -489,7 +554,10 @@ def load_model(params):

         if not is_trt_engine and (precision == "fp16" or precision == "half"):
             # If model is TensorRT serialized engine then model.half will report failure
-            model = model.half()
+            if model is not None:
+                model = model.half()
+            if model_torch is not None:
+                model_torch = model_torch.half()

         backends = params.get("backend")
         # Run inference
@@ -502,6 +570,7 @@ def load_model(params):
             truncate_long_and_double,
             batch_size,
             is_trt_engine,
+            model_torch,
         )
     else:
         params = vars(args)
@@ -511,23 +580,27 @@ def load_model(params):
         model_name = params["model"]
         model_name_torch = params["model_torch"]
         model_torch = None

-        # Load TorchScript model
+        # Load TorchScript model, if provided
         if os.path.exists(model_name):
             print("Loading user provided torchscript model: ", model_name)
             model = torch.jit.load(model_name).cuda().eval()
         elif model_name in BENCHMARK_MODELS:
             print("Loading torchscript model from BENCHMARK_MODELS for: ", model_name)
             model = BENCHMARK_MODELS[model_name]["model"].eval().cuda()
-        else:
-            raise ValueError(
-                "Invalid model name. Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)"
-            )

         # Load PyTorch Model, if provided
         if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
             print("Loading user provided torch model: ", model_name_torch)
             model_torch = torch.load(model_name_torch).eval().cuda()

+        # If neither model type was provided
+        if (model is None) and (model_torch is None):
+            raise ValueError(
+                "No valid models specified. Please provide a torchscript model file or model name "
+                + "(among the following options vgg16|resnet50|efficientnet_b0|vit) "
+                + "or provide a torch model file"
+            )
+
         backends = parse_backends(params["backends"])
         truncate_long_and_double = params["truncate"]
         batch_size = params["batch_size"]
diff --git a/tools/perf/utils.py b/tools/perf/utils.py
index bd10bad798..96a13ffbc2 100644
--- a/tools/perf/utils.py
+++ b/tools/perf/utils.py
@@ -5,7 +5,10 @@ import timm

 BENCHMARK_MODELS = {
-    "vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
+    "vgg16": {
+        "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
+        "path": ["script", "pytorch"],
+    },
     "resnet50": {
         "model": models.resnet50(weights=None),
         "path": ["script", "pytorch"],
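Taken together, the backend keywords touched by this change dispatch to the following runners in `perf_run.py`'s `run()` (a summary sketch; the `run_*` names are the actual functions, while the dict itself is illustrative):

```python
# Illustrative summary (not code from this diff): which benchmark
# runners each --backends keyword triggers after this change.
BACKEND_RUNNERS = {
    "all": ["run_torch", "run_torch_tensorrt", "run_tensorrt", "run_fx2trt"],
    "torchscript": ["run_torch", "run_torch_tensorrt", "run_tensorrt"],
    "torch": ["run_torch"],
    "torch_tensorrt": ["run_torch_tensorrt"],
    "tensorrt": ["run_tensorrt"],
    "fx2trt": ["run_fx2trt"],
}
```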