diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index a078cabda3f67..c1bbbb331139a 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -254,8 +255,13 @@ class ConstantFolder : public MixedModeMutator { // needed for both execution and creation(due to JIT) With fresh_build_ctx(transform::PassContext::Create()); - Map dict = - (module_->attrs.defined()) ? module_->attrs->dict : Map(); + Map dict = (module_->attrs.defined()) + ? Map(module_->attrs.CopyOnWrite()->dict) + : Map(); + + // always use graph executor with no link-params + dict.Set(tvm::attr::kExecutor, + relay::Executor::Create("graph", {{"link-params", Bool(false)}})); Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), eval_cpu_dev_, eval_cpu_target_, dict)); VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index c165d140b1a68..298c4f177fd1b 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -17,6 +17,7 @@ import numpy as np import tvm from tvm import relay +from tvm.relay.backend import Executor from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.testing import run_infer_type, create_workload @@ -369,6 +370,24 @@ def before(): tvm.ir.assert_structural_equal(run_infer_type(before_mod["main"]), after_mod["main"]) +def test_pass_link_params(): + """ + This test checks ensures that proper executor is passed to interpreter instance + The test will fail if FoldConstant does not override the executor due to "int8" + is not supported in ScheduleBuilder + """ + + def expr(): + z = relay.const(10, dtype="int8") + return relay.cast(z, dtype="int32") + + mod = tvm.IRModule.from_expr(expr()) + mod = tvm.relay.transform.InferType()(mod) + # Add executor with link-params + mod = mod.with_attr("executor", Executor("aot", {"link-params": True})) + mod = tvm.relay.transform.FoldConstant()(mod) + + if __name__ == "__main__": import sys import pytest