From 98527d2acb587e41900d20c3e9de8777c955a5c0 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Wed, 27 May 2020 18:45:06 -0700
Subject: [PATCH] fix(//cpp/benchmark): reorder benchmark so FP16 bn issue in
 JIT doesn't interfere with TRTorch

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 cpp/benchmark/README.md |  6 ++++--
 cpp/benchmark/main.cpp  | 42 ++++++++++++++++++++++++++----------------
 2 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/cpp/benchmark/README.md b/cpp/benchmark/README.md
index 73041c9ff3..1c45bc9fbe 100644
--- a/cpp/benchmark/README.md
+++ b/cpp/benchmark/README.md
@@ -1,6 +1,6 @@
 # Benchmarking
 
-This is a quick benchmarking application for TRTorch. It lets you run supported TorchScript modules both in JIT and TRT and returns the average runtime and throughput. 
+This is a quick benchmarking application for TRTorch. It lets you run supported TorchScript modules both in JIT and TRT and returns the average runtime and throughput.
 
 ## Compilation / Usage
 
@@ -20,7 +20,7 @@ bazel run //cpp/benchmark --cxxopt="-DNDEBUG" --cxxopt="-DJIT" --cxxopt="-DTRT"
 
 ### Options
 
-You can run a module with JIT or TRT via TRTorch in either FP32 or FP16. These options are controlled by preprocessor directives. 
+You can run a module with JIT or TRT via TRTorch in either FP32 or FP16. These options are controlled by preprocessor directives.
 
 - To enable JIT profiling, add the argument `--cxxopt="-DJIT"`
 
@@ -28,4 +28,6 @@ You can run a module with JIT or TRT via TRTorch in either FP32 or FP16. These o
 
 - To enable FP16 execution, add the argument `--cxxopt="-DHALF"`
 
+- To also save the TRT engine, add the argument `--cxxopt="-DSAVE_ENGINE"`
+
 > It's suggested to also define `--cxxopt="-DNDEBUG"` to suppress debug information
diff --git a/cpp/benchmark/main.cpp b/cpp/benchmark/main.cpp
index 05c84cf7ac..e73f1da4e8 100644
--- a/cpp/benchmark/main.cpp
+++ b/cpp/benchmark/main.cpp
@@ -105,15 +105,6 @@ int main(int argc, const char* argv[]) {
 
     mod.to(at::kCUDA);
 
-#ifdef HALF
-    mod.to(torch::kHalf);
-    for (auto layer : mod.named_modules()) {
-        if (layer.name.find(".bn") != std::string::npos) {
-            layer.value.to(torch::kFloat);
-        }
-    }
-#endif
-
     std::vector<std::vector<int64_t>> dims;
     for (int i = 2; i < argc; i++) {
         auto arg = std::string(argv[i]);
@@ -129,23 +120,42 @@ int main(int argc, const char* argv[]) {
 
     at::globalContext().setBenchmarkCuDNN(true);
 
-#ifdef JIT
-    auto jit_runtimes = benchmark_module(mod, dims[0]);
-    print_avg_std_dev("JIT", jit_runtimes, dims[0][0]);
-#endif
-
 #ifdef TRT
     auto extra_info = trtorch::ExtraInfo(dims);
-    extra_info.workspace_size = 1 << 24;
+    extra_info.workspace_size = 1 << 20;
 
 #ifdef HALF
-    extra_info.op_precision = at::kHalf;
+    extra_info.op_precision = torch::kF16;
 #endif
 
     auto trt_mod = trtorch::CompileGraph(mod, extra_info);
+
+#ifdef SAVE_ENGINE
+    std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
+    auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info);
+    std::ofstream out("/tmp/engine_converted_from_jit.trt");
+    out << engine;
+    out.close();
+#endif
+
     auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
     print_avg_std_dev("JIT/TRT", trt_runtimes, dims[0][0]);
 #endif
+
+#ifdef HALF
+    mod.to(torch::kHalf);
+    for (auto layer : mod.named_modules()) {
+        if (layer.name.find(".bn") != std::string::npos) {
+            layer.value.to(torch::kFloat);
+        }
+    }
+#endif
+
+#ifdef JIT
+    auto jit_runtimes = benchmark_module(mod, dims[0]);
+    print_avg_std_dev("JIT", jit_runtimes, dims[0][0]);
+#endif
+
     std::cout << "ok\n";
 }