Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] support pd_op.expand convert to cinn_opbroadcast_to #59437

Merged
merged 17 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
Expand Down Expand Up @@ -82,19 +83,6 @@ bool IsSameDim(const phi::DDim& first, const std::vector<int64_t>& second) {
return false;
}

std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
const std::vector<int64_t>& out_shape) {
std::vector<int64_t> broadcast_axes(in_shape.size(), 0);
auto in_shape_size = in_shape.size();
if (in_shape_size >= 1) {
for (int i = 1; i <= in_shape_size; ++i) {
broadcast_axes[in_shape_size - i] = out_shape.size() - i;
}
}

return broadcast_axes;
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
auto x_dims = op->operand_source(0)
.type()
Expand Down Expand Up @@ -126,7 +114,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
} else {
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(0),
GetBroadcastAxis(x_dims, output_shape),
cinn::hlir::framework::pir::GetBroadcastAxis(x_dims, output_shape),
output_shape);

op->operand(0).set_source(new_transpose_op->result(0));
Expand All @@ -152,7 +140,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
} else {
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(1),
GetBroadcastAxis(y_dims, output_shape),
cinn::hlir::framework::pir::GetBroadcastAxis(y_dims, output_shape),
output_shape);

op->operand(1).set_source(new_transpose_op->result(0));
Expand Down
50 changes: 50 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
Expand Down Expand Up @@ -356,6 +357,54 @@ class AddNOpPattern : public pir::OpRewritePattern<paddle::dialect::AddNOp> {
}
};

class ExpandOpPattern
: public pir::OpRewritePattern<paddle::dialect::ExpandOp> {
public:
using pir::OpRewritePattern<paddle::dialect::ExpandOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::ExpandOp op,
pir::PatternRewriter &rewriter) const override {
auto out_shape_gen_op = op->operand_source(1)
.dyn_cast<pir::OpResult>()
.owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>();

if (out_shape_gen_op) {
auto section_attr = out_shape_gen_op.attribute("value")
.dyn_cast<pir::ArrayAttribute>()
.AsVector();

std::vector<int64_t> output_shape;
if (section_attr.size() > 0) {
for (size_t i = 0; i < section_attr.size(); ++i) {
output_shape.push_back(
section_attr[i].dyn_cast<::pir::Int64Attribute>().data());
}
}

auto in_dim = op.operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();

auto broadcast_axis =
cinn::hlir::framework::pir::GetBroadcastAxis(in_dim, output_shape);

auto out = rewriter
.Build<cinn::dialect::BroadcastOp>(
op.operand_source(0), broadcast_axis, output_shape)
.result(0);

rewriter.ReplaceAllUsesWith(op.result(0), out);

rewriter.EraseOp(op);
return true;
}

return false;
}
};

class SplitWithNumOpPattern
: public pir::OpRewritePattern<paddle::dialect::SplitWithNumOp> {
public:
Expand Down Expand Up @@ -478,6 +527,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<SplitWithNumOpPattern>(context);
ps.Add<AddNOpPattern>(context);
ps.Add<SplitOpPattern>(context);
ps.Add<ExpandOpPattern>(context);
// ps.Add(UniformOpPattern().Build(context));

return ps;
Expand Down
13 changes: 13 additions & 0 deletions paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,19 @@ std::vector<int> CompatibleInfo::ValueShape(const ::pir::Value& value) {
return phi::vectorize<int>(dim);
}

std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
const std::vector<int64_t>& out_shape) {
std::vector<int64_t> broadcast_axes(in_shape.size(), 0);
auto in_shape_size = in_shape.size();
if (in_shape_size >= 1) {
for (int i = 1; i <= in_shape_size; ++i) {
broadcast_axes[in_shape_size - i] = out_shape.size() - i;
}
}

return broadcast_axes;
}

} // namespace pir
} // namespace framework
} // namespace hlir
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/hlir/framework/pir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/phi/core/ddim.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个后续建议使用common中ddim.h

#include "paddle/pir/core/operation.h"

namespace cinn {
Expand Down Expand Up @@ -77,6 +78,9 @@ struct CompatibleInfo {
static OpPatternKind OpKind(const ::pir::Operation& op);
};

std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
const std::vector<int64_t>& out_shape);

} // namespace pir
} // namespace framework
} // namespace hlir
Expand Down