Skip to content

Commit

Permalink
fix UT
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 committed Nov 22, 2023
1 parent a042b64 commit d9a56d4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
10 changes: 9 additions & 1 deletion paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd_op.multiply", "elementwise_mul"},
{"cinn_op.reshape", "reshape"},
{"cinn_op.scale", "scale"},
{"cinn_op.broadcast", "broadcast_to"}};
{"cinn_op.broadcast", "broadcast_to"},
// The following should implement OpPattern in pd_to_cinn_pass,
// otherwise, it will be block in BuildCinnPass.
{"cinn_op.squeeze", ""},
{"cinn_op.unsqueeze", ""}};

// In following cases, the op is marked SupportCinn:
// 1. its name is in OP_NAMES, like pd_op.sum;
Expand Down Expand Up @@ -78,6 +82,9 @@ std::string CompatibleInfo::OpName(const ::pir::Operation& op) {
}
auto cinn_op_name = name.substr(pos + 1);
VLOG(4) << "GetOpName: " << name << " -> " << cinn_op_name;
CHECK(cinn_op_name != "")
<< "Found empty cinn_op_name, maybe you should implement OpPattern for "
<< name;
return cinn_op_name;
}

Expand Down Expand Up @@ -256,6 +263,7 @@ OpPatternKind CompatibleInfo::OpKind(const ::pir::Operation& op) {
kind = hlir::framework::kElementWise;
}
}
VLOG(4) << op_name << " OpPatternKind: " << kind;
return kind;
}

Expand Down
26 changes: 18 additions & 8 deletions paddle/fluid/pir/transforms/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,32 @@ std::string GetDebugInfo(const std::unordered_set<std::string>& names) {
return debug_info;
}

bool IsSupportCinn(pir::Operation* op) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
VLOG(4) << "The allowed Cinn Ops: " << GetDebugInfo(allow_ops);
VLOG(4) << "The denied Cinn Ops: " << GetDebugInfo(deny_ops);

// In case of op has some attributes generated by FullOp, it need
// implement OpPattern in pd_to_cinn_pass. Otherwise, we mark them
// as unimplement ops.
bool UnimplementOps(pir::Operation* op) {
// cinn not support uniform, the FullOp of max and min support NOT generate by
// CINN
if (op->isa<paddle::dialect::FullOp>()) {
auto out = op->result(0);
if (out.use_count() > 0 &&
out.first_use().owner()->isa<paddle::dialect::UniformOp>()) {
return false;
return true;
}
} else if (op->isa<paddle::dialect::DropoutOp>()) {
return true;
}
if (op->isa<paddle::dialect::DropoutOp>()) {
return false;
}

bool IsSupportCinn(pir::Operation* op) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
VLOG(4) << "The allowed Cinn Ops: " << GetDebugInfo(allow_ops);
VLOG(4) << "The denied Cinn Ops: " << GetDebugInfo(deny_ops);

if (UnimplementOps(op)) {
VLOG(4) << "Found UnimplementOps: " << op->name();
return false;
}

Expand Down

0 comments on commit d9a56d4

Please sign in to comment.