Skip to content

Commit

Permalink
Fold If when the condition is Constant (apache#7354)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and alexwong committed Feb 11, 2021
1 parent 38851c2 commit 70efe33
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ class ConstantFolder : public MixedModeMutator {
}
}

Expr VisitExpr_(const IfNode* op) final {
auto new_cond = ExprMutator::VisitExpr(op->cond);
if (auto const_cond = new_cond.as<ConstantNode>()) {
if (reinterpret_cast<uint8_t*>(const_cond->data->data)[0]) {
return ExprMutator::VisitExpr(op->true_branch);
} else {
return ExprMutator::VisitExpr(op->false_branch);
}
}
return ExprMutator::VisitExpr_(op);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (inside_primitive) {
return GetRef<Expr>(call);
Expand Down
39 changes: 39 additions & 0 deletions tests/python/relay/test_pass_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,45 @@ def expected():
assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_if():
cond_data = np.array(1).astype("bool")
x_data = np.array([[1, 2, 3]]).astype("float32")

def before():
a = relay.const(cond_data)
x = relay.const(x_data)
y = relay.const(x_data)
iff = relay.If(a, x + y, x - y)
return relay.Function([], iff)

def expected():
y_data = x_data + x_data
y = relay.const(y_data)
return relay.Function([], y)

zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, zexpected)

cond_data = np.array(0).astype("bool")

def before():
a = relay.const(cond_data)
x = relay.const(x_data)
y = relay.const(x_data)
iff = relay.If(a, x + y, x - y)
return relay.Function([], iff)

def expected():
y_data = x_data - x_data
y = relay.const(y_data)
return relay.Function([], y)

zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_shape_of():
c_shape = (8, 9, 10)

Expand Down

0 comments on commit 70efe33

Please sign in to comment.