chore: update doc for partitioning
Signed-off-by: Bo Wang <wangbo1995ee@163.com>
bowang007 committed May 6, 2021
1 parent 78e67cc commit 6e0da4d
Showing 1 changed file with 7 additions and 5 deletions: `core/partitioning/README.md`.

# TRTorch Partitioning

The TRTorch partitioning phase was developed to support the `automatic fallback` feature in TRTorch. This phase does not run by
default; it only runs when the automatic fallback feature is enabled.

At a high level, the TRTorch partitioning phase does the following:
- …from the user. Shapes can be calculated by running the graphs with JIT.
- …it's still a phase in our partitioning process.
- `Stitching`. Stitch all the TensorRT engines and PyTorch nodes together.

Test cases for each of these components can be found [here](https://github.com/NVIDIA/TRTorch/tree/master/tests/core/partitioning).
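
As a rough illustration of these phases (a sketch added here, not taken from the README itself), consider a toy module in which `aten::relu` is forced to fall back:

```python
# Illustrative sketch only. If aten::relu is forced to fall back, segmentation
# splits this graph into three segments: [aten::add] -> TensorRT,
# [aten::relu] -> PyTorch, [aten::mul] -> TensorRT. Shape analysis then records
# each segment's input shapes, the TensorRT segments are converted to engines,
# and stitching reassembles everything into a single module.
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        y = x + x           # convertible op -> TensorRT segment
        y = torch.relu(y)   # forced fallback -> PyTorch segment
        return y * 2        # convertible op -> TensorRT segment

toy = torch.jit.script(Toy())  # partitioning operates on this TorchScript graph
print(toy.graph)               # the graph that gets segmented
```

With a `min_block_size` larger than 1, the single-op TensorRT segments in this toy example would be too small to convert and would themselves stay in PyTorch.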

Here is a brief description of the functionality in each file:
- `PartitionInfo.h/cpp`: The automatic fallback APIs that are used for partitioning.
- `SegmentedBlock.h/cpp`: The main data structures that are used to maintain information for each segment after segmentation.

To enable the automatic fallback feature, you can set the following attributes in Python:

```
...
"torch_fallback" : {
    "enabled" : True,
    "min_block_size" : 3,
    "forced_fallback_ops": ["aten::add"],
}
})
```
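
Below is a sketch of a complete call for context. The `trtorch.compile` entry point, the `input_shapes` key, and the file name are assumptions made for this illustration; only the `torch_fallback` attributes are the ones documented above.

```python
import torch
import trtorch  # assumed import name for the TRTorch Python package

# Sketch of where the attributes above fit in a full compile spec.
# "ts_module.ts" is a placeholder for a saved TorchScript module.
scripted_mod = torch.jit.load("ts_module.ts")
trt_mod = trtorch.compile(scripted_mod, {
    "input_shapes": [[1, 3, 224, 224]],        # assumed key for fixed input shapes
    "torch_fallback": {
        "enabled": True,                        # turn automatic fallback on
        "min_block_size": 3,                    # only convert runs of >= 3 contiguous supported ops
        "forced_fallback_ops": ["aten::add"],   # always run aten::add in PyTorch
    },
})
out = trt_mod(torch.randn(1, 3, 224, 224).cuda())
```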

Similarly, these attributes can be set through the C++ API:

```
// "in" is an example input tensor; its original declaration is elided in this diff.
auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA});
auto mod = trtorch::jit::load("trt_ts_module.ts");
auto input_sizes = std::vector<trtorch::CompileSpec::InputRange>{{in.sizes()}};
trtorch::CompileSpec cfg(input_sizes);
cfg.torch_fallback = trtorch::CompileSpec::TorchFallback(true);   // enable automatic fallback
cfg.torch_fallback.min_block_size = 2;                            // only convert runs of >= 2 contiguous supported ops
cfg.torch_fallback.forced_fallback_ops.push_back("aten::relu");   // always run aten::relu in PyTorch
auto trt_mod = trtorch::CompileGraph(mod, cfg);
auto out = trt_mod.forward({in});
```
