Skip to content

Commit

Permalink
[FoldConstant] Create Interpreter for each constant subgraph (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and Trevor Morris committed Aug 26, 2020
1 parent c374ce6 commit 5ff7a8e
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec
// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
public:
explicit ConstantFolder(FInterpreter executor, IRModule module)
: executor_(executor),
module_(module),
explicit ConstantFolder(IRModule module)
: module_(module),
shape_of_op_(Op::Get("shape_of")),
vm_shape_of_op_(Op::Get("vm.shape_of")),
invoke_tvm_op_(Op::Get("vm.invoke_tvm_op")),
Expand Down Expand Up @@ -163,8 +162,6 @@ class ConstantFolder : public ExprMutator {
}

private:
// Internal interepreter.
FInterpreter executor_;
// Internal constant checker
ConstantChecker checker_;
// Module
Expand All @@ -180,6 +177,20 @@ class ConstantFolder : public ExprMutator {
const Op& cast_op_;
const Op& ndarray_size_op_;

// Create an interpreter.
FInterpreter GetInterpreter(const IRModule& mod) {
using tvm::transform::PassContext;
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());

return CreateInterpreter(mod, ctx, target);
}

// Convert value to expression.
Expr ObjectToExpr(const ObjectRef& value) {
if (value->IsInstance<runtime::NDArray::ContainerType>()) {
Expand Down Expand Up @@ -218,7 +229,9 @@ class ConstantFolder : public ExprMutator {
mod = seq(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main"));
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ObjectToExpr(executor_(expr));

FInterpreter executor = GetInterpreter(mod);
return ObjectToExpr(executor(expr));
}

// Evaluate a call to the shape_of operator for tensors with constant
Expand Down Expand Up @@ -331,16 +344,7 @@ class ConstantFolder : public ExprMutator {
};

Expr FoldConstant(const Expr& expr, const IRModule& mod) {
using tvm::transform::PassContext;
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());

return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
return ConstantFolder(mod).Mutate(expr);
}

namespace transform {
Expand Down

0 comments on commit 5ff7a8e

Please sign in to comment.