Skip to content

Commit

Permalink
cinn(py-dsl): refactore code
Browse files Browse the repository at this point in the history
  • Loading branch information
6clc committed Sep 21, 2023
1 parent 9d16ac8 commit 13395b1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
14 changes: 7 additions & 7 deletions paddle/cinn/ir/utils/ir_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
return true;
}

if (only_compare_sturcture_ && !lhs.defined() && !rhs.defined()) {
if (only_compare_structure_ && !lhs.defined() && !rhs.defined()) {
return true;
}

Expand Down Expand Up @@ -188,7 +188,7 @@ bool IrEqualVisitor::Visit(const Call* lhs, const Expr* other) {
Compare(lhs->write_args, rhs->write_args) &&
Compare(lhs->attrs, rhs->attrs) &&
lhs->call_type == rhs->call_type;
if (only_compare_sturcture_) {
if (only_compare_structure_) {
return flag;
}
return lhs->name == rhs->name && flag;
Expand All @@ -200,7 +200,7 @@ bool IrEqualVisitor::Visit(const _Var_* lhs, const Expr* other) {
bool flag = Compare(lhs->lower_bound, rhs->lower_bound) &&
Compare(lhs->upper_bound, rhs->upper_bound) &&
lhs->tag == rhs->tag;
if (only_compare_sturcture_) {
if (only_compare_structure_) {
return flag;
}
return lhs->name == rhs->name && flag;
Expand Down Expand Up @@ -239,7 +239,7 @@ bool IrEqualVisitor::Visit(const _Buffer_* lhs, const Expr* other) {
lhs->offset_factor == rhs->offset_factor && lhs->target == rhs->target &&
lhs->data_alignment == rhs->data_alignment &&
lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype;
if (only_compare_sturcture_) {
if (only_compare_structure_) {
return flag;
}
return flag && lhs->name == rhs->name;
Expand All @@ -248,7 +248,7 @@ bool IrEqualVisitor::Visit(const _Buffer_* lhs, const Expr* other) {
bool IrEqualVisitor::Visit(const _Tensor_* lhs, const Expr* other) {
auto* rhs = other->As<_Tensor_>();
bool flag = Compare(lhs->shape, rhs->shape);
if (only_compare_sturcture_) {
if (only_compare_structure_) {
return flag;
}
return flag && Compare(lhs->name, rhs->name);
Expand Down Expand Up @@ -304,7 +304,7 @@ bool IrEqualVisitor::Visit(const _Module_* lhs, const Expr* other) {
Compare(lhs->functions, rhs->functions) &&
Compare(lhs->submodules, rhs->submodules);

if (only_compare_sturcture_) {
if (only_compare_structure_) {
return flag;
}

Expand Down Expand Up @@ -375,7 +375,7 @@ bool IrEqualVisitor::Visit(const ScheduleBlock* lhs, const Expr* other) {
Compare(lhs->write_buffers, rhs->write_buffers) &&
Compare(lhs->body, rhs->body);

if (only_compare_sturcture_) {
if (only_compare_structure_) {
return flag;
}
return flag && Compare(lhs->attrs, rhs->attrs) &&
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/ir/utils/ir_compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ namespace ir_utils {
class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
public:
explicit IrEqualVisitor(bool allow_name_suffix_diff = false,
bool only_compare_sturcture = false)
bool only_compare_structure = false)
: allow_name_suffix_diff_(allow_name_suffix_diff),
only_compare_sturcture_(only_compare_sturcture) {}
only_compare_structure_(only_compare_structure) {}
// Return true if they are euqal, otherwise false;
bool Compare(const Expr& lhs, const Expr& rhs);

Expand All @@ -47,7 +47,7 @@ class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
// whether allowing name suffix ends with "_[0-9]+" different
bool allow_name_suffix_diff_ = false;
// not compare name field of Expr
bool only_compare_sturcture_ = false;
bool only_compare_structure_ = false;
};

bool IRCompare(const Expr& lhs,
Expand Down
4 changes: 0 additions & 4 deletions paddle/cinn/pybind/ir/ir_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,11 @@ class IRContext {
CHECK(data_.get()) << "IrContext holds null";
auto* ctx_node = data_.get()->safe_as<TIRContextNode>();
if (!ctx_node) {
// TODO(6clc):
std::stringstream err_msg;
err_msg << "TypeConvertError: convert " << data_.get()->type_info()
<< " to " << TIRContextNode::__type_info__;

CINN_THROW(err_msg.str());
// CINN_THROW(...) << "TypeConvertError: convert " <<
// data_.get()->type_info()
// << " to " << TIRContextNode::__type_info__;
}
return ctx_node;
}
Expand Down

0 comments on commit 13395b1

Please sign in to comment.