Skip to content

Commit

Permalink
feat(//core/conversion): Compiler can now create graphs
Browse files Browse the repository at this point in the history
out of programs that use conditionals if it can be gaurenteed that
there is a single code path followed through the course of the
program given input information and the graph

This means that right now conditionals within loops is not supported
but if a program has a bunch of evaluatable cases and those cases
produce tensors as long as the program does not need to run both
branches conditionally at runtime the program can still be compiled

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jun 14, 2020
1 parent 07ba980 commit 9d1946e
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,27 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
});

for (auto p : input_output_pairs) {
auto input = ctx->evaluated_value_map[p.first];
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
if (ctx->evaluated_value_map.find(p.first) != ctx->evaluated_value_map.end()) {
auto input = ctx->evaluated_value_map[p.first];
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
} else if (ctx->value_tensor_map.find(p.first) != ctx->value_tensor_map.end()) {
auto input = ctx->value_tensor_map[p.first];
ctx->value_tensor_map[p.second] = input;
} else {
TRTORCH_THROW_ERROR("Cannot find Value " << p.first->debugName() << " either evaluated values or tensor maps (MapIValues)");
}
}
}

void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n, bool contained_in_loop = false) {
bool output_type_includes_tensor = false;
for (auto o : n->outputs()) {
if (o->type()->isSubtypeOf(c10::TensorType::get())) {
output_type_includes_tensor = true;
}
}
TRTORCH_CHECK(!(contained_in_loop && output_type_includes_tensor), "TRTorch currently cannot compile conditionals within loops");

auto condition = ctx->evaluated_value_map[n->input(0)].toBool();
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition);
auto b = condition ? n->blocks()[0] : n->blocks()[1];
Expand All @@ -215,16 +230,19 @@ void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
if (bn->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, bn);
} else if (bn->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, bn);
} else {
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile conditionals that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.")
EvaluateConditionalBlock(ctx, bn, contained_in_loop);
} else if (evaluators::shouldEvalAtConversionTime(bn)) {
auto eval = EvaluateNode(ctx, bn);
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
} else {
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
}
ctx->AssociateValueAndIValue(bn->output(0), eval.value());
} else if (converters::node_is_convertable(bn)) {
AddLayer(ctx, bn);
} else {
TRTORCH_THROW_ERROR("TRTorch is unable to compile this conditional, a converter or evaluator is not available for node " << *bn);
}
}

Expand All @@ -251,7 +269,7 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
if (bn->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, n);
} else if (bn->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, bn);
EvaluateConditionalBlock(ctx, bn, true);
} else {
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
auto eval = EvaluateNode(ctx, bn);
Expand Down

0 comments on commit 9d1946e

Please sign in to comment.