Skip to content

Commit

Permalink
[PIR] support pd_op.expand convert to cinn_op.broadcast_to (#59437)
Browse files Browse the repository at this point in the history
* pir cinn support multi group

* update

* update

* fix pir cinn pow op bug

* remove useless code

* update

* update
  • Loading branch information
phlrain authored Nov 30, 2023
1 parent 6e5aad0 commit 2d89d60
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
Expand Down Expand Up @@ -82,19 +83,6 @@ bool IsSameDim(const phi::DDim& first, const std::vector<int64_t>& second) {
return false;
}

std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
const std::vector<int64_t>& out_shape) {
std::vector<int64_t> broadcast_axes(in_shape.size(), 0);
auto in_shape_size = in_shape.size();
if (in_shape_size >= 1) {
for (int i = 1; i <= in_shape_size; ++i) {
broadcast_axes[in_shape_size - i] = out_shape.size() - i;
}
}

return broadcast_axes;
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
auto x_dims = op->operand_source(0)
.type()
Expand Down Expand Up @@ -126,7 +114,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
} else {
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(0),
GetBroadcastAxis(x_dims, output_shape),
cinn::hlir::framework::pir::GetBroadcastAxis(x_dims, output_shape),
output_shape);

op->operand(0).set_source(new_transpose_op->result(0));
Expand All @@ -152,7 +140,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
} else {
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(1),
GetBroadcastAxis(y_dims, output_shape),
cinn::hlir::framework::pir::GetBroadcastAxis(y_dims, output_shape),
output_shape);

op->operand(1).set_source(new_transpose_op->result(0));
Expand Down
50 changes: 50 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
Expand Down Expand Up @@ -356,6 +357,54 @@ class AddNOpPattern : public pir::OpRewritePattern<paddle::dialect::AddNOp> {
}
};

class ExpandOpPattern
    : public pir::OpRewritePattern<paddle::dialect::ExpandOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::ExpandOp>::OpRewritePattern;

  // Rewrites pd_op.expand into cinn_op.broadcast when the target shape is a
  // compile-time constant produced by a FullIntArrayOp. Returns false (no
  // rewrite) when the shape is not statically known here.
  bool MatchAndRewrite(paddle::dialect::ExpandOp op,
                       pir::PatternRewriter &rewriter) const override {
    // The shape operand may have no defining op (e.g. it is a block
    // argument), in which case dyn_cast<pir::OpResult>() yields a null
    // OpResult; guard before calling owner() to avoid dereferencing null.
    auto shape_result = op->operand_source(1).dyn_cast<pir::OpResult>();
    if (!shape_result) {
      return false;
    }
    auto out_shape_gen_op =
        shape_result.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>();
    if (!out_shape_gen_op) {
      return false;
    }

    // Unpack the constant target shape from the "value" array attribute.
    auto section_attr = out_shape_gen_op.attribute("value")
                            .dyn_cast<pir::ArrayAttribute>()
                            .AsVector();
    std::vector<int64_t> output_shape;
    output_shape.reserve(section_attr.size());
    for (size_t i = 0; i < section_attr.size(); ++i) {
      output_shape.push_back(
          section_attr[i].dyn_cast<::pir::Int64Attribute>().data());
    }

    auto in_dim = op.operand_source(0)
                      .type()
                      .dyn_cast<paddle::dialect::DenseTensorType>()
                      .dims();

    // Map each input axis onto its right-aligned output axis.
    auto broadcast_axis =
        cinn::hlir::framework::pir::GetBroadcastAxis(in_dim, output_shape);

    auto out = rewriter
                   .Build<cinn::dialect::BroadcastOp>(
                       op.operand_source(0), broadcast_axis, output_shape)
                   .result(0);

    rewriter.ReplaceAllUsesWith(op.result(0), out);
    rewriter.EraseOp(op);
    return true;
  }
};

class SplitWithNumOpPattern
: public pir::OpRewritePattern<paddle::dialect::SplitWithNumOp> {
public:
Expand Down Expand Up @@ -478,6 +527,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<SplitWithNumOpPattern>(context);
ps.Add<AddNOpPattern>(context);
ps.Add<SplitOpPattern>(context);
ps.Add<ExpandOpPattern>(context);
// ps.Add(UniformOpPattern().Build(context));

return ps;
Expand Down
13 changes: 13 additions & 0 deletions paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,19 @@ std::vector<int> CompatibleInfo::ValueShape(const ::pir::Value& value) {
return phi::vectorize<int>(dim);
}

// Computes the numpy-style (right-aligned) broadcast axis mapping: entry k
// is the output axis that input axis k corresponds to, i.e.
// out_rank - in_rank + k. An empty input shape yields an empty mapping.
std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
                                      const std::vector<int64_t>& out_shape) {
  const int rank_in = in_shape.size();
  const int64_t rank_out = static_cast<int64_t>(out_shape.size());
  std::vector<int64_t> broadcast_axes(rank_in, 0);
  for (int k = 0; k < rank_in; ++k) {
    broadcast_axes[k] = rank_out - rank_in + k;
  }
  return broadcast_axes;
}

} // namespace pir
} // namespace framework
} // namespace hlir
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/hlir/framework/pir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/pir/core/operation.h"

namespace cinn {
Expand Down Expand Up @@ -77,6 +78,9 @@ struct CompatibleInfo {
static OpPatternKind OpKind(const ::pir::Operation& op);
};

std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
const std::vector<int64_t>& out_shape);

} // namespace pir
} // namespace framework
} // namespace hlir
Expand Down

0 comments on commit 2d89d60

Please sign in to comment.