Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewZhaoLuo committed Feb 3, 2022
1 parent 8ce1b6c commit 82cfe72
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 5 deletions.
63 changes: 59 additions & 4 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ class SimplifyReshape : public DFPatternRewrite {
};

/*!
* \brief SimplifyCast matches the pattern of cast data to the same dtype.
* \brief SimplifySameCast matches the pattern of cast data to the same dtype.
*/
class SimplifyCast : public DFPatternRewrite {
class SimplifySameCast : public DFPatternRewrite {
public:
SimplifyCast() {
SimplifySameCast() {
data_pat_ = IsWildcard();
like_pat_ = IsWildcard();
pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
Expand All @@ -104,6 +104,60 @@ class SimplifyCast : public DFPatternRewrite {
DFPattern like_pat_;
};

/*!
* \brief SimplifyConsecutiveCast matches the pattern of consecutive cast/cast_like ops
*/
class SimplifyConsecutiveCast : public DFPatternRewrite {
public:
SimplifyConsecutiveCast() {
data_ = IsWildcard();
cast1_ = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_});
pattern_ = IsOp("cast_like")({cast1_, IsWildcard()}) || IsOp("cast")({cast1_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto data = node_map[data_][0];
auto cast1 = Downcast<Call>(node_map[cast1_][0]);
auto data_type = Downcast<TensorType>(data->checked_type());
DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;

if (!IsWidenCast(data_type->dtype, cast1_dtype)) {
// Cannot remove the narrow cast
return post;
}

const CallNode* cast2 = post.as<CallNode>();
DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;
auto expr = MakeCast(data, cast2_dtype);

// We need to set the checked type as it may be needed in the next callback
expr->checked_type_ = TensorType(data_type->shape, cast2_dtype);
return expr;
}

bool IsWidenCast(DataType origin, DataType cast) const {
/* Return whether casting from origin to cast results in more or the same precision.*/
if (origin.code() == cast.code() && origin.bits() <= cast.bits()) {
return true;
}
if (origin.code() == DataType::kBFloat || cast.code() == DataType::kBFloat) {
// BFloat cast cannot be omitted
return false;
}
if (origin.code() < cast.code()) {
// Loosely have a hiearchy to datatypes
// e.g. int --> uint --> float has increasing range of numbers they can represent
return true;
}
return false;
}

protected:
DFPattern data_;
DFPattern cast1_;
};

/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
Expand Down Expand Up @@ -640,7 +694,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
composer.AddRewrite<SimplifyTranspose>();
composer.AddRewrite<SimplifyCast>();
composer.AddRewrite<SimplifySameCast>();
composer.AddRewrite<SimplifyConsecutiveCast>();
composer.AddRewrite<FullElementwise>();
composer.AddRewrite<SimplifyConsecutiveAdd>();
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
Expand Down
32 changes: 31 additions & 1 deletion tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def check(x, y=None, do_nothing=False):
check(id_op(const, x), id_op(op_like(x), x))


def test_simplify_cast():
def test_simplify_same_cast():
dtype = "int32"
data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
expr1 = relay.cast(data, dtype)
Expand All @@ -416,6 +416,36 @@ def test_simplify_cast():
assert tvm.ir.structural_equal(actual2, expected)


def test_simplify_consecutive_cast():
x = relay.var("x", shape=(3, 4, 5), dtype="int8")
y = relay.var("y", shape=(3, 4), dtype="int64")
z = relay.var("z", shape=(3,), dtype="float32")

expr1 = relay.cast(x, "int16")
expr2 = relay.cast(expr1, "int32")
expr3 = relay.cast_like(expr2, y)
expr4 = relay.cast_like(expr3, z)

actual1 = run_opt_pass(expr2, relay.transform.SimplifyExpr())
expected = run_infer_type(relay.cast(x, "int32"))
assert tvm.ir.structural_equal(actual1, expected)
actual2 = run_opt_pass(expr3, relay.transform.SimplifyExpr())
expected = run_infer_type(relay.cast(x, "int64"))
assert tvm.ir.structural_equal(actual2, expected)
actual3 = run_opt_pass(expr4, relay.transform.SimplifyExpr())
expected = run_infer_type(relay.cast(x, "float32"))
assert tvm.ir.structural_equal(actual3, expected)

# cannot simplify the narrow cast
x = relay.var("x", shape=(3, 4, 5), dtype="float32")
y = relay.var("y", shape=(3, 4), dtype="float32")
expr1 = relay.cast(x, "int32")
expr2 = relay.cast_like(expr1, y)
actual = run_opt_pass(expr2, relay.transform.SimplifyExpr())
expected = run_infer_type(expr2)
assert tvm.ir.structural_equal(actual, expected)


def test_concretize_reshape_like():
data = relay.var("data", shape=(2, 3, 4), dtype="float32")
shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
Expand Down

0 comments on commit 82cfe72

Please sign in to comment.