From 9d1946e03b4e11b6bd89f5a8f52d39a331d47eb2 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 13 Jun 2020 19:09:34 -0700 Subject: [PATCH] feat(//core/conversion): Compiler can now create graphs out of programs that use conditionals if it can be gaurenteed that there is a single code path followed through the course of the program given input information and the graph This means that right now conditionals within loops is not supported but if a program has a bunch of evaluatable cases and those cases produce tensors as long as the program does not need to run both branches conditionally at runtime the program can still be compiled Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/conversion.cpp | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index aea53ff6b1..59faaefd9b 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -201,12 +201,27 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef in_l }); for (auto p : input_output_pairs) { - auto input = ctx->evaluated_value_map[p.first]; - ctx->evaluated_value_map[p.second] = torch::jit::IValue(input); + if (ctx->evaluated_value_map.find(p.first) != ctx->evaluated_value_map.end()) { + auto input = ctx->evaluated_value_map[p.first]; + ctx->evaluated_value_map[p.second] = torch::jit::IValue(input); + } else if (ctx->value_tensor_map.find(p.first) != ctx->value_tensor_map.end()) { + auto input = ctx->value_tensor_map[p.first]; + ctx->value_tensor_map[p.second] = input; + } else { + TRTORCH_THROW_ERROR("Cannot find Value " << p.first->debugName() << " either evaluated values or tensor maps (MapIValues)"); + } } } -void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) { +void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n, bool contained_in_loop = false) { + bool output_type_includes_tensor = false; + for (auto o : n->outputs()) { + if (o->type()->isSubtypeOf(c10::TensorType::get())) { + output_type_includes_tensor = true; + } + } + TRTORCH_CHECK(!(contained_in_loop && output_type_includes_tensor), "TRTorch currently cannot compile conditionals within loops"); + auto condition = ctx->evaluated_value_map[n->input(0)].toBool(); LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition); auto b = condition ? n->blocks()[0] : n->blocks()[1]; @@ -215,9 +230,8 @@ void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) { if (bn->kind() == torch::jit::prim::Loop) { EvaluateLoopBlock(ctx, bn); } else if (bn->kind() == torch::jit::prim::If) { - EvaluateConditionalBlock(ctx, bn); - } else { - TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile conditionals that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.") + EvaluateConditionalBlock(ctx, bn, contained_in_loop); + } else if (evaluators::shouldEvalAtConversionTime(bn)) { auto eval = EvaluateNode(ctx, bn); if (!eval.value().isTensor()) { LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value()); @@ -225,6 +239,10 @@ void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) { LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')'); } ctx->AssociateValueAndIValue(bn->output(0), eval.value()); + } else if (converters::node_is_convertable(bn)) { + AddLayer(ctx, bn); + } else { + TRTORCH_THROW_ERROR("TRTorch is unable to compile this conditional, a converter or evaluator is not available for node " << *bn); } } @@ -251,7 +269,7 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) { if (bn->kind() == torch::jit::prim::Loop) { EvaluateLoopBlock(ctx, n); } else if (bn->kind() == torch::jit::prim::If) { - EvaluateConditionalBlock(ctx, bn); + EvaluateConditionalBlock(ctx, bn, true); } else { TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated."); auto eval = EvaluateNode(ctx, bn);