Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FoldConstant] Create Interpreter for each constant subgraph #6195

Merged
merged 1 commit into from
Aug 3, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec
// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
public:
explicit ConstantFolder(FInterpreter executor, IRModule module)
: executor_(executor),
module_(module),
explicit ConstantFolder(IRModule module)
: module_(module),
shape_of_op_(Op::Get("shape_of")),
vm_shape_of_op_(Op::Get("vm.shape_of")),
invoke_tvm_op_(Op::Get("vm.invoke_tvm_op")),
Expand Down Expand Up @@ -163,8 +162,6 @@ class ConstantFolder : public ExprMutator {
}

private:
// Internal interepreter.
FInterpreter executor_;
// Internal constant checker
ConstantChecker checker_;
// Module
Expand All @@ -180,6 +177,20 @@ class ConstantFolder : public ExprMutator {
const Op& cast_op_;
const Op& ndarray_size_op_;

// Create an interpreter.
FInterpreter GetInterpreter(const IRModule& mod) {
using tvm::transform::PassContext;
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());

return CreateInterpreter(mod, ctx, target);
}

// Convert value to expression.
Expr ObjectToExpr(const ObjectRef& value) {
if (value->IsInstance<runtime::NDArray::ContainerType>()) {
Expand Down Expand Up @@ -218,7 +229,9 @@ class ConstantFolder : public ExprMutator {
mod = seq(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main"));
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ObjectToExpr(executor_(expr));

FInterpreter executor = GetInterpreter(mod);
return ObjectToExpr(executor(expr));
}

// Evaluate a call to the shape_of operator for tensors with constant
Expand Down Expand Up @@ -331,16 +344,7 @@ class ConstantFolder : public ExprMutator {
};

Expr FoldConstant(const Expr& expr, const IRModule& mod) {
using tvm::transform::PassContext;
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());

return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
return ConstantFolder(mod).Mutate(expr);
}

namespace transform {
Expand Down