fix: Add test case, move config condition
- Add a test case to elicit the behavior where full compilation is requested
but the TRT engine size falls below the default `min_block_size=3`
- Move the `min_block_size` condition to a narrower scope
gs-olive committed Feb 1, 2023
1 parent 0e670d5 commit 1209225
Showing 4 changed files with 63 additions and 21 deletions.
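
For context, a minimal sketch of the scenario this commit targets, mirroring the new test below and assuming the `torch_tensorrt` package is imported as `torchtrt` (as in the test file): a model whose entire graph is TRT-convertible but contains fewer operations than the default `min_block_size=3`, compiled with `require_full_compilation=True`.

import torch
import torch_tensorrt as torchtrt  # assumed import alias, matching the test file

class TinyModel(torch.nn.Module):
    # A single TRT-convertible op; the whole graph is smaller than min_block_size=3
    def forward(self, x, y):
        return (x + y,)

scripted = torch.jit.script(TinyModel().eval().to("cuda"))
trt_mod = torchtrt.ts.compile(
    scripted,
    inputs=[
        torchtrt.Input((5, 5), dtype=torch.float),
        torchtrt.Input((5, 5), dtype=torch.float),
    ],
    require_full_compilation=True,  # with this fix, partitioning overrides min_block_size to 1
)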
35 changes: 16 additions & 19 deletions core/compiler.cpp
@@ -143,19 +143,14 @@ partitioning::GraphAndMapping BuildHybridGraph(
auto convert_info = cfg.convert_info;
auto partitioning_info = cfg.partitioning_info;

// Any nonzero block size is valid if full compilation to TRT is desired
if (expect_full_compilation) {
partitioning_info.min_block_size = 1;
}

auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
partitioning_ctx.input_types_map = first_use_types;

// Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
// TODO: Combine this within partition call
partitioning::populateInputIValues(&partitioning_ctx);

partitioning::partition(&partitioning_ctx);
partitioning::partition(&partitioning_ctx, expect_full_compilation);

for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
@@ -197,9 +192,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
if (expect_full_compilation) {
for (auto torch_node : seg_block.block()->nodes()) {
if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
LOG_ERROR(
"Full compilation specified but node " << torch_node->kind().toQualString()
<< " was executed in Torch.");
TORCHTRT_THROW_ERROR(
"Full compilation specified but node "
<< *torch_node
<< " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
<< " Try recompiling with require_full_compilation=False.");
}
}
}
@@ -209,10 +206,9 @@
// If full compilation is expected, cannot have more than 2 Torch segments
// (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
LOG_ERROR(
"Full compilation specified but number of torch segments was "
<< num_torch_segments << " and number of trt segments was " << num_trt_segments
<< ". Was expecting at most 2 Torch segments and 1 TRT segment.");
TORCHTRT_THROW_ERROR(
"Full compilation was requested but unable to convert all operations to TensorRT."
<< " Try recompiling with require_full_compilation=False.");
}
}

@@ -224,7 +220,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
ir::CollectionTypeMap& first_use_type_map,
bool expect_full_compilation = false) {
bool requires_collection_handling = false) {
cfg.convert_info.collection_input_spec_map =
std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
cfg.partitioning_info.collection_input_spec_map =
@@ -259,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec[i].dtype = at::kFloat;
} else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || expect_full_compilation)) {
} else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
std::stringstream ss;
@@ -352,10 +348,10 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
// whether full compilation can be expected
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
auto requires_collection_handling = (isBlockConvertible && outputIsCollection);

// Extract map of IValue to DType
auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, nearly_full_compilation);
auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);

// Check whether any of the input types are Long
bool user_requested_long = false;
@@ -380,10 +376,11 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
outputIsCollection || user_requested_long)) ||
nearly_full_compilation) {
requires_collection_handling) {
// If the model is fully-compilable and the user has specified full compilation, run partitioning
// to generate collection-processing code in Torch
auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info.enabled);
auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);

auto graph_and_mapping =
BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
new_g = graph_and_mapping.first;
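The two TORCHTRT_THROW_ERROR calls above turn what was previously only a logged error into a hard failure when full compilation is requested but a node or segment must stay in Torch. A hedged sketch of the user-facing effect, assuming `torch.nonzero` cannot be converted and that Torch-TensorRT errors surface in Python as RuntimeError (the exact message depends on which check detects the fallback):

import torch
import torch_tensorrt as torchtrt  # assumed import alias

class MaybeUnsupported(torch.nn.Module):
    def forward(self, x):
        # Assumed here to stay in Torch due to its data-dependent output shape
        return torch.nonzero(x)

try:
    torchtrt.ts.compile(
        torch.jit.script(MaybeUnsupported().eval().to("cuda")),
        inputs=[torchtrt.Input((5, 5), dtype=torch.float)],
        require_full_compilation=True,
    )
except RuntimeError as err:
    print(err)  # may suggest recompiling with require_full_compilation=False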
16 changes: 15 additions & 1 deletion core/partitioning/partitioning.cpp
@@ -564,7 +564,21 @@ void populateInputIValues(PartitioningCtx* ctx) {
}
}

void partition(PartitioningCtx* ctx) {
void partition(PartitioningCtx* ctx, bool expect_full_compilation) {
// If full compilation is expected, overwrite minimum block size
// Any nonzero block size is valid if full compilation to TRT is desired
// Override the default min_block_size to ensure all TRT-supported operations are
// executed in TRT, regardless of the size of the graph
if (expect_full_compilation) {
// If minimum block size is different from the default, the user must have specified it
if (ctx->settings.min_block_size != 3) {
LOG_WARNING(
"Detected user-specified min_block_size with require_full_compilation=True "
<< "disregarding min_block_size.");
}
ctx->settings.min_block_size = 1;
}

LOG_DEBUG(ctx->settings);

// Go through all the blocks to do the partitioning
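With the override moved into partition(), any user-supplied min_block_size is disregarded (and a warning logged) whenever full compilation is expected. A short usage sketch, reusing TinyModel and the torchtrt import alias from the sketch near the top of this page, and assuming min_block_size is exposed as a keyword argument on torchtrt.ts.compile:

trt_mod = torchtrt.ts.compile(
    torch.jit.script(TinyModel().eval().to("cuda")),
    inputs=[
        torchtrt.Input((5, 5), dtype=torch.float),
        torchtrt.Input((5, 5), dtype=torch.float),
    ],
    min_block_size=5,               # non-default value triggers the new warning
    require_full_compilation=True,  # partition() overrides the block size to 1
)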
2 changes: 1 addition & 1 deletion core/partitioning/partitioning.h
@@ -48,7 +48,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);

GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);

void partition(PartitioningCtx* ctx);
void partition(PartitioningCtx* ctx, bool expect_full_compilation = false);

} // namespace partitioning
} // namespace core
31 changes: 31 additions & 0 deletions tests/py/api/test_e2e_behavior.py
@@ -146,6 +146,37 @@ def forward(self, x, y, z):
trt_output, torch_output
), "Found differing output formatting between Torch-TRT and Torch"

def test_tuple_output_with_full_compilation(self):
class Sample(torch.nn.Module):
def __init__(self):
super(Sample, self).__init__()

def forward(self, x, y):
a = x + y
return (a,)

self.model = Sample().eval().to("cuda")
self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
self.input_2 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
scripted_mod = torch.jit.script(self.model)

inputs = [
torchtrt.Input((5, 5), dtype=torch.float),
torchtrt.Input((5, 5), dtype=torch.float),
]

trt_mod = torchtrt.ts.compile(
scripted_mod,
inputs=inputs,
require_full_compilation=True,
enabled_precisions={torch.float, torch.half},
)
trt_output = trt_mod(self.input_1, self.input_2)
torch_output = self.model(self.input_1, self.input_2)
assert same_output_format(
trt_output, torch_output
), "Found differing output formatting between Torch-TRT and Torch"


if __name__ == "__main__":
unittest.main()
