Skip to content

Commit

Permalink
[Pass] Simplify consecutive casts in Relay
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed May 19, 2021
1 parent c999a84 commit 719b151
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
39 changes: 35 additions & 4 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ class SimplifyReshape : public DFPatternRewrite {
};

/*!
* \brief SimplifyCast matches the pattern of cast data to the same dtype.
* \brief SimplifySameCast matches the pattern of cast data to the same dtype.
*/
class SimplifyCast : public DFPatternRewrite {
class SimplifySameCast : public DFPatternRewrite {
public:
SimplifyCast() {
SimplifySameCast() {
data_pat_ = IsWildcard();
like_pat_ = IsWildcard();
pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
Expand All @@ -104,6 +104,36 @@ class SimplifyCast : public DFPatternRewrite {
DFPattern like_pat_;
};

/*!
* \brief SimplifyConsecutiveCast matches the pattern of consecutive cast/cast_like ops
*/
class SimplifyConsecutiveCast : public DFPatternRewrite {
public:
SimplifyConsecutiveCast() {
data_ = IsWildcard();
auto cast1 = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_});
pattern_ = IsOp("cast_like")({cast1, IsWildcard()}) || IsOp("cast")({cast1});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
static const Op& cast_op = Op::Get("cast");
static const Op& cast_like_op = Op::Get("cast_like");
auto data = node_map[data_][0];
const CallNode* call = post.as<CallNode>();
if (call->op == cast_op) {
auto attr = call->attrs.as<CastAttrs>();
CHECK(attr);
return MakeCast(data, attr->dtype);
}
// cast_like op
return Call(cast_like_op, {data, call->args[1]}, Attrs(), {});
}

protected:
DFPattern data_;
};

/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
Expand Down Expand Up @@ -597,7 +627,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
composer.AddRewrite<SimplifyTranspose>();
composer.AddRewrite<SimplifyCast>();
composer.AddRewrite<SimplifySameCast>();
composer.AddRewrite<SimplifyConsecutiveCast>();
composer.AddRewrite<FullElementwise>();
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}
Expand Down
23 changes: 22 additions & 1 deletion tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def check(x, y=None, do_nothing=False):
check(id_op(const, x), id_op(op_like(x), x))


def test_simplify_cast():
def test_simplify_same_cast():
dtype = "int32"
data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
expr1 = relay.cast(data, dtype)
Expand All @@ -416,6 +416,27 @@ def test_simplify_cast():
assert tvm.ir.structural_equal(actual2, expected)


def test_simplify_consecutive_cast():
dtype = "int32"
x = relay.var("x", shape=(3, 4, 5), dtype="int32")
y = relay.var("y", shape=(3, 4), dtype="int8")
z = relay.var("z", shape=(3,), dtype="float32")
expr1 = relay.cast(x, "int64")
expr2 = relay.cast(expr1, "int16")
expr3 = relay.cast_like(expr2, y)
expr4 = relay.cast_like(expr3, z)

actual1 = run_opt_pass(expr2, relay.transform.SimplifyExpr())
expected = run_infer_type(relay.cast(x, "int16"))
assert tvm.ir.structural_equal(actual1, expected)
actual2 = run_opt_pass(expr3, relay.transform.SimplifyExpr())
expected = run_infer_type(relay.cast_like(x, y))
assert tvm.ir.structural_equal(actual2, expected)
actual3 = run_opt_pass(expr4, relay.transform.SimplifyExpr())
expected = run_infer_type(relay.cast_like(x, z))
assert tvm.ir.structural_equal(actual3, expected)


def test_concretize_reshape_like():
data = relay.var("data", shape=(2, 3, 4), dtype="float32")
shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
Expand Down

0 comments on commit 719b151

Please sign in to comment.