[PIR] Pir cinn support elementwise pow and reshape (PaddlePaddle#58908)
* pir cinn support scale op

* revert code

* fix bug

* pir_cinn_support_elementwise_pow_and_reshape

* update

* revert some code

* update

* update
phlrain authored Nov 12, 2023
1 parent 0f69826 commit dcdfdb6
Showing 7 changed files with 141 additions and 5 deletions.
8 changes: 8 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
@@ -24,6 +24,14 @@
kernel :
func : frobenius_norm

- op : reshape
args : (Tensor x, int[] shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
kernel :
func : reshape

- op : scale
args : (Tensor x, float scale=1.0, float bias=0.0, bool bias_after_scale=true)
output : Tensor(out)
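For orientation: the new reshape entry declares an op that takes a tensor plus a static int[] shape and reuses ReshapeInferMeta and the reshape kernel. The rewrite pattern later in this commit builds it exactly this way (a minimal sketch; input_value is an illustrative pir::Value, not a name from the diff):

    // Build a cinn_op.reshape that flattens input_value to 1-D.
    auto reshaped = rewriter.Build<cinn::dialect::ReshapeOp>(
        input_value, std::vector<int>({-1}));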
@@ -151,6 +151,8 @@ bool AddBroadcastToElementwisePass::Initialize(pir::IrContext* context) {
ps.Add<AddBrodcastToElementwisePattern<paddle::dialect::SubtractOp>>(context);
ps.Add<AddBrodcastToElementwisePattern<paddle::dialect::MultiplyOp>>(context);
ps.Add<AddBrodcastToElementwisePattern<paddle::dialect::DivideOp>>(context);
ps.Add<AddBrodcastToElementwisePattern<paddle::dialect::ElementwisePowOp>>(
context);

patterns_ = ::pir::FrozenRewritePatternSet(std::move(ps));
return true;
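With ElementwisePowOp added to the pattern set, an elementwise_pow whose operands disagree in shape now gets an explicit broadcast inserted before it, like the other binary elementwise ops. A hedged before/after sketch (shapes are illustrative, and the inserted op is assumed to be cinn_op.broadcast, which utils.cc below maps to CINN's broadcast_to):

    // Before: out = pd_op.elementwise_pow(x /*[16,16]*/, y /*[16]*/)
    // After:  t   = cinn_op.broadcast(y)   // y broadcast up to [16,16]
    //         out = pd_op.elementwise_pow(x, t)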
@@ -42,9 +42,12 @@ std::unordered_map<std::string, OpPatternKind> OpKindMap = {
{"pd_op.exp", OpPatternKind::kElementWise},
{"pd_op.sin", OpPatternKind::kElementWise},
{"pd_op.cos", OpPatternKind::kElementWise},
{"pd_op.pow", OpPatternKind::kElementWise},
{"pd_op.elementwise_pow", OpPatternKind::kElementWise},
{"pd_op.sum", OpPatternKind::kReduction},
{"cinn_op.reshape", OpPatternKind::kElementWise},
{"pd_op.cast", OpPatternKind::kElementWise},
{"pd_op.greater_than", OpPatternKind::kElementWise},
{"pd_op.sum", OpPatternKind::kReduction},
{"cinn_op.scale", OpPatternKind::kElementWise},
{"cinn_op.reduce_sum", OpPatternKind::kReduction},
{"cinn_op.reduce_max", OpPatternKind::kReduction},
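OpKindMap drives fusion grouping: roughly, kElementWise ops fuse freely while kReduction ops constrain how groups form. A hypothetical lookup helper (GetOpKind is an illustrative name, not part of this commit) makes the map's role concrete:

    OpPatternKind GetOpKind(const std::string& op_name) {
      auto it = OpKindMap.find(op_name);
      // Assumption: ops absent from the map are treated as non-fusible.
      return it != OpKindMap.end() ? it->second : OpPatternKind::kNonFusible;
    }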
44 changes: 44 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -130,6 +130,49 @@ class ScaleOpPattern : public pir::OpRewritePattern<paddle::dialect::ScaleOp> {
}
};

class ReshapeOpPattern
: public pir::OpRewritePattern<paddle::dialect::ReshapeOp> {
public:
using pir::OpRewritePattern<paddle::dialect::ReshapeOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::ReshapeOp op,
pir::PatternRewriter &rewriter) const override {
auto scale_factor_gen_op =
op->operand_source(1).dyn_cast<pir::OpResult>().owner();

if (auto full_op =
scale_factor_gen_op->dyn_cast<paddle::dialect::FullIntArrayOp>()) {
      // the target shape is produced by a full_int_array op;
      // read the shape values from its "value" attribute

auto out_shape_attr =
full_op.attribute("value").dyn_cast<pir::ArrayAttribute>().AsVector();

std::vector<int> vec_out_shape;
if (out_shape_attr.size() > 0) {
PADDLE_ENFORCE_EQ(
out_shape_attr[0].isa<::pir::Int64Attribute>(),
true,
phi::errors::Unimplemented(
"the 0th elementwise MUST be ir::Int64Attribute"));
for (size_t i = 0; i < out_shape_attr.size(); ++i) {
vec_out_shape.push_back(
out_shape_attr[i].dyn_cast<::pir::Int64Attribute>().data());
}
}

auto cinn_reshape = rewriter.Build<cinn::dialect::ReshapeOp>(
op->operand_source(0).dyn_cast<pir::OpResult>(), vec_out_shape);
rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0));
rewriter.EraseOp(op);
rewriter.EraseOp(full_op);

return true;
}
return false;
}
};
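In effect the pattern folds the shape-producing full_int_array op into a static attribute carried by the CINN op. Schematically:

    // Before: s   = pd_op.full_int_array(value=[-1])
    //         out = pd_op.reshape(x, s)
    // After:  out = cinn_op.reshape(x)  // target shape [-1] stored on the op

Note the pattern only fires when the shape operand comes from a FullIntArrayOp; for dynamically computed shapes, MatchAndRewrite returns false and the op is left untouched.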

class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
@@ -186,6 +229,7 @@ bool PdOpToCinnOpPass::Initialize(pir::IrContext *context) {
context); // NOTE: the scale op pattern must run before AddBroadcastTo
ps.Add(SumOpPattern().Build(context));
ps.Add(MaxOpPattern().Build(context));
ps.Add<ReshapeOpPattern>(context);
// ps.Add(UniformOpPattern().Build(context));

patterns_ = ::pir::FrozenRewritePatternSet(std::move(ps));
14 changes: 11 additions & 3 deletions paddle/cinn/hlir/framework/pir/utils.cc
@@ -38,13 +38,21 @@ const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd_op.add", "elementwise_add"},
{"pd_op.subtract", "subtract"},
{"pd_op.divide", "divide"},
{"pd_op.elementwise_pow", "pow"},
{"pd_op.multiply", "elementwise_mul"},
{"cinn_op.broadcast", "broadcast_to"},
{"cinn_op.scale", "scale"}};
{"cinn_op.reshape", "reshape"},
{"cinn_op.scale", "scale"},
{"cinn_op.broadcast", "broadcast_to"}};

// Tagging PaddleDialect Op with REGISTER_OP_MAPPER(OP)
const std::unordered_set<std::string> CompatibleInfo::CINN_WHITE_OPS = {
"subtract", "divide", "broadcast_to", "multiply", "scale"};
"subtract",
"divide",
"broadcast_to",
"multiply",
"scale",
"elementwise_pow",
"reshape"};

bool CompatibleInfo::IsSupportCinn(const ::pir::Operation& op) {
return CINN_WHITE_OPS.find(CompatibleInfo::OpName(op)) !=
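The diff view truncates the function here; given the set-membership idiom, it presumably completes as follows (an inference from context, not shown in the hunk):

    bool CompatibleInfo::IsSupportCinn(const ::pir::Operation& op) {
      return CINN_WHITE_OPS.find(CompatibleInfo::OpName(op)) !=
             CINN_WHITE_OPS.end();
    }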
5 changes: 4 additions & 1 deletion paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
@@ -123,7 +123,10 @@ def gen_cpp_file_code(self, cpp_file_path):
if len(op_info_item.attribute_name_list) > 0:
params_no_mutable_attr.append("attrs")

if len(op_info_item.mutable_attribute_name_list) == 0:
if (
self.dialect_name != "pd_op"
or len(op_info_item.mutable_attribute_name_list) == 0
):
body_code += NORMAL_FUNCTION_TEMPLATE.format(
op_name=ir_op_name,
namespace=Dialect2NameSpaceMap[self.dialect_name],
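The widened guard means creator functions for ops outside the pd_op dialect (including the cinn_op ops added by this commit) always take the plain attribute-map path; mutable attributes, i.e. attributes materialized as full/full_int_array operands, are a pd_op-specific mechanism. A hedged sketch of the two code paths (comments only; the actual template output differs):

    // pd_op with mutable attributes: the generated creator receives extra
    //   operand Values (e.g. a full_int_array result) alongside the inputs.
    // cinn_op, or pd_op without mutable attributes: the generated creator
    //   receives the inputs plus an AttributeMap and builds the op directly.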
68 changes: 68 additions & 0 deletions test/cpp/pir/cinn/pir_all_path_test.cc
@@ -427,3 +427,71 @@ TEST(GroupOp, TestBuildScaleTensor) {
bool res0 = simple_cmp(out_tensor.data<float>()[0], 0.5);
EXPECT_EQ(res0, true);
}

std::shared_ptr<::pir::Program> BuildPowerProgram() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());

auto x = builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({16, 16}),
2.0,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);

auto factor =
builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({16, 16}),
2.0,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);
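  // Both inputs are 16x16 tensors filled with 2.0; the elementwise power
  // therefore yields 4.0 everywhere, and the reshape flattens it to 1-D.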

auto power =
builder.Build<paddle::dialect::ElementwisePowOp>(x, factor).result(0);
auto out =
builder
.Build<paddle::dialect::ReshapeOp>(power, std::vector<int64_t>({-1}))
.result(0);

builder.Build<paddle::dialect::FetchOp>(out, "out", 0);
return program;
}

TEST(GroupOp, TestBuildPower) {
// Step 1: Construct pir::Program
::pir::IrContext* ctx = ::pir::IrContext::Instance();
std::shared_ptr<::pir::Program> program = BuildPowerProgram();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();

cinn::dialect::ir::PdOp2CinnOpConverter(program.get());

pir::PassManager pm(ctx);
pm.AddPass(
std::make_unique<cinn::dialect::ir::AddBroadcastToElementwisePass>());
pm.AddPass(pir::CreateBuildCinnPass());
CHECK_EQ(pm.Run(program.get()), true);

auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get());

paddle::platform::Place place = paddle::platform::CUDAPlace(0);

auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(res.get(), place);

paddle::framework::Scope exe_scope;

paddle::framework::InterpreterCore executor(
place, {"out@fetch"}, kernel_program->block(), &exe_scope);

executor.Run({}, true);

auto out_tensor =
executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>();
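  // pow(2.0, 2.0) == 4.0, so every element of the flattened output is 4.0.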

bool res0 = simple_cmp(out_tensor.data<float>()[0], 4.0);
EXPECT_EQ(res0, true);
}
