diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 48af31f9a11f..66f233bbba85 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -120,6 +120,18 @@ class ConstantFolder : public MixedModeMutator { } } + Expr VisitExpr_(const IfNode* op) final { + auto new_cond = ExprMutator::VisitExpr(op->cond); + if (auto const_cond = new_cond.as()) { + if (reinterpret_cast(const_cond->data->data)[0]) { + return ExprMutator::VisitExpr(op->true_branch); + } else { + return ExprMutator::VisitExpr(op->false_branch); + } + } + return ExprMutator::VisitExpr_(op); + } + Expr Rewrite_(const CallNode* call, const Expr& post) final { if (inside_primitive) { return GetRef(call); diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 549596d61693..76182d2c3e08 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -147,6 +147,45 @@ def expected(): assert tvm.ir.structural_equal(zz, zexpected) +def test_fold_if(): + cond_data = np.array(1).astype("bool") + x_data = np.array([[1, 2, 3]]).astype("float32") + + def before(): + a = relay.const(cond_data) + x = relay.const(x_data) + y = relay.const(x_data) + iff = relay.If(a, x + y, x - y) + return relay.Function([], iff) + + def expected(): + y_data = x_data + x_data + y = relay.const(y_data) + return relay.Function([], y) + + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(zz, zexpected) + + cond_data = np.array(0).astype("bool") + + def before(): + a = relay.const(cond_data) + x = relay.const(x_data) + y = relay.const(x_data) + iff = relay.If(a, x + y, x - y) + return relay.Function([], iff) + + def expected(): + y_data = x_data - x_data + y = relay.const(y_data) + return relay.Function([], y) + + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(zz, zexpected) + + def test_fold_shape_of(): c_shape = (8, 9, 10)