From 6e0da4d45fab29944b06ee6be45c64d5bcbd8141 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 5 May 2021 21:32:37 -0500 Subject: [PATCH] chore: update doc for partitioning Signed-off-by: Bo Wang --- core/partitioning/README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/partitioning/README.md b/core/partitioning/README.md index 938d0f2353..9e7b4d74c0 100644 --- a/core/partitioning/README.md +++ b/core/partitioning/README.md @@ -1,6 +1,6 @@ # TRTorch Partitioning -TRTorch partitioning phase is developed to support automatic fallback feature in TRTorch. This phase won't run by +TRTorch partitioning phase is developed to support `automatic fallback` feature in TRTorch. This phase won't run by default until the automatic fallback feature is enabled. On a high level, TRTorch partitioning phase does the following: @@ -15,6 +15,8 @@ from the user. Shapes can be calculated by running the graphs with JIT. it's still a phase in our partitioning process. - `Stitching`. Stitch all TensorRT engines with PyTorch nodes altogether. +Test cases for each of these components could be found [here](https://github.com/NVIDIA/TRTorch/tree/master/tests/core/partitioning). + Here is the brief description of functionalities of each file: - `PartitionInfo.h/cpp`: The automatic fallback APIs that is used for partitioning. - `SegmentedBlock.h/cpp`: The main data structures that is used to maintain information for each segments after segmentation. @@ -34,8 +36,8 @@ To enable automatic fallback feature, you can set following attributes in Python ... "torch_fallback" : { "enabled" : True, - "min_block_size" : 1, - "forced_fallback_ops": ["aten::foo"], + "min_block_size" : 3, + "forced_fallback_ops": ["aten::add"], } }) ``` @@ -58,8 +60,8 @@ auto mod = trtorch::jit::load("trt_ts_module.ts"); auto input_sizes = std::vector{{in.sizes()}}; trtorch::CompileSpec cfg(input_sizes); cfg.torch_fallback = trtorch::CompileSpec::TorchFallback(true); -cfg.torch_fallback.min_block_size = 1; -cfg.torch_fallback.forced_fallback_ops.push_back("aten::foo"); +cfg.torch_fallback.min_block_size = 2; +cfg.torch_fallback.forced_fallback_ops.push_back("aten::relu"); auto trt_mod = trtorch::CompileGraph(mod, cfg); auto out = trt_mod.forward({in}); ```