diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml index 22006e1ae4570..2e42323782839 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml +++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml @@ -74,6 +74,7 @@ func : SliceRawInferMeta kernel : func : slice + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : uniform_random args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0) diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 3f4e7a9344a30..a98f097e43d23 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -64,8 +64,9 @@ set(KERNEL_LIST # shared inference library deps list(REMOVE_DUPLICATES fluid_modules) -#windows GPU static library over the limit, so not create_static_lib, and cc_library is dummy -if(WIN32 AND WITH_GPU) +# windows static library(both CPU and GPU)over the limit, so no longer create_static_lib, +# and cc_library is dummy +if(WIN32) cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API} ${utils_modules}) else() diff --git a/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py b/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py index d85ed967418d5..ff2094a3df009 100644 --- a/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py @@ -13,11 +13,9 @@ # limitations under the License. OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """ -bool {op_name}::InferSymbolicShape(pir::Builder &builder, - const std::vector &operands, - std::vector &reified_return_shapes) {{ +bool {op_name}::InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis) {{ VLOG(4) << "Infer symbolic shape for op: {op_name}"; - return {op_name}InferSymbolicShape(builder, operands, reified_return_shapes); + return {op_name}InferSymbolicShape(this->operation(), shape_analysis); }} """ diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 7dd754e868f86..e7a131a98d05b 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -131,9 +131,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ """ infer_symbolic_shape_template = """ - static bool InferSymbolicShape(pir::Builder &builder, - const std::vector &operands, - std::vector &reified_return_shapes); + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis); """ # ===================================== diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 676e4b9d574b9..1b9ca43b7d9f1 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -13,16 +13,15 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" namespace paddle::dialect { bool InferSymbolicShapeInterface::InferSymbolicShape( - pir::Builder &builder, - const std::vector &operands, - std::vector &reified_return_shapes) { - return impl_->infer_symbolic_shapes( - operation(), builder, operands, reified_return_shapes); + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return impl_->infer_symbolic_shapes(operation(), shape_analysis); } } // namespace paddle::dialect @@ -30,124 +29,176 @@ namespace paddle::dialect { namespace { -bool DeriveShapeFromOperand(pir::Builder *builder, - pir::Value operand, - std::vector *reified_return_shapes) { - auto shaped_type = operand.type().dyn_cast(); - if (!shaped_type) return false; - reified_return_shapes->assign( - {builder->Build(operand).result(0)}); +bool InferSymbolicShapeAllEqualUnary( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + shape_analysis->value_id_to_shapeordata_[res_id] = + shape_analysis->value_id_to_shapeordata_[operand_source_id]; return true; } -// Returns a new scalar integer value having type `type`. -// Here `type` must be an integer or index type. -pir::Value MaybeCastTo(pir::Builder &builder, // NOLINT - pir::Value value, - pir::Type type) { - if (type == value.type()) return value; - // if (!type.IsIndex() && !value.type().IsIndex()) { - // Value casted = - // builder.Build(builder.index_type(), value) - // .result(0); - // return builder.Build(type, casted).result(0); - // } - // return builder.Build(type, value).result(0); +bool InferSymbolicShapeAllEqualBinary( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + shape_analysis->value_id_to_shapeordata_[res_id] = + shape_analysis->value_id_to_shapeordata_[operand_source_id]; + return true; } + } // namespace -bool AbsOpInferSymbolicShape( - pir::Builder &builder, // NOLINT - const std::vector &operands, - std::vector &reified_return_shapes) { // NOLINT - return DeriveShapeFromOperand( - &builder, operands.front().source(), &reified_return_shapes); -} - -bool Abs_OpInferSymbolicShape( - pir::Builder &builder, // NOLINT - const std::vector &operands, - std::vector &reified_return_shapes) { // NOLINT - return DeriveShapeFromOperand( - &builder, operands.front().source(), &reified_return_shapes); -} - -bool TransposeOpInferSymbolicShape( - pir::Builder &builder, // NOLINT - const std::vector &operands, - std::vector &reified_return_shapes) { // NOLINT - // auto operand_type = operands[0].type().dyn_cast(); - // // Currently not support unranked type. - // if (!operand_type) return false; - // std::vector permutation = this->permutation(); - // std::vector shape_values(permutation.size()); - // Type shape_scalar_type = builder.index_type(); - // auto to_shape_scalar_type = [&](Value v) { - // return MaybeCastTo(builder, v, shape_scalar_type); - // }; - // auto shaped_type = operand_type.dyn_cast(); - // auto shape_vector = shaped_type.GetDyShape(); - // for (auto [idx, element] = std::tuple{0, shape_vector.begin()}; - // element != shape_vector.end(); - // ++idx, ++element) { - // auto it = std::find(permutation.begin(), permutation.end(), idx); - // // TODO(zhangbopd): Need BuildOrFold - // Value value_dim = to_shape_scalar_type( - // builder.Build(operands[0].source(), - // idx).result(0)); - // shape_values[std::distance(permutation.begin(), it)] = value_dim; - // } - // Value output_shape = - // builder.Build(shape_values).result(0); - // reified_return_shapes.push_back(output_shape); +bool AbsOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool Abs_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool CastOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool Cast_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool ExpOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} +bool Exp_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool SubtractOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualBinary(op, shape_analysis); +} + +bool Subtract_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualBinary(op, shape_analysis); +} + +bool ShapeOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + std::vector dims = + common::vectorize(res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = res_dim_expr; + } else { + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; + } + shapes.push_back(dim_expr); + } + + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } -bool ConcatOpInferSymbolicShape( - pir::Builder &builder, // NOLINT - const std::vector &operands, - std::vector &reified_return_shapes) { // NOLINT - // std::vector inputs = {x()}; - // auto operand_type = inputs[0].type().dyn_cast(); - // // Currently not support unranked type. - // if (!operand_type) return false; - // Type shapeScalarType = builder.index_type(); - // auto to_shape_scalar_type = [&](Value v) { - // return MaybeCastTo(builder, v, shapeScalarType); - // }; - // std::vector> all_shape_values; - // for (size_t inputId = 0; inputId < inputs.size(); ++inputId) { - // Value operand = inputs[inputId]; - // auto operand_type = operand.type().dyn_cast(); - // if (!operand_type) return false; - // std::vector shape_values; - // auto shaped_type = operand_type.dyn_cast(); - // auto shape_vector = shaped_type.GetDyShape(); - // for (auto [idx, element] = std::tuple{0, shape_vector.begin()}; - // element != shape_vector.end(); - // ++idx, ++element) { - // Value value_dim = to_shape_scalar_type( - // builder.Build(operand, idx).result(0)); - // shape_values.push_back(value_dim); - // } - // all_shape_values.emplace_back(std::move(shape_values)); - // } - // [[maybe_unused]] int axis = this->dimension(); - // auto &shape_values = all_shape_values[0]; - // for (size_t vecId = 1; vecId < all_shape_values.size(); ++vecId) { - // auto &otherShapeValues = all_shape_values[vecId]; - // if (otherShapeValues.size() != shape_values.size()) return false; - // TODO(zhangbopd): AddIOp - // shape_values[axis] = - // builder.Build(shape_values[axis], - // otherShapeValues[axis]); - // } - // Value output_shape = - // builder.Build(shape_values).result(0); - // reified_return_shapes.push_back(output_shape); +bool ShapeSrOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return ShapeOpInferSymbolicShape(op, shape_analysis); +} + +bool StackOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + symbol::ShapeOrDataDimExprs shape_data; + shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id]; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + return true; +} + +bool ReshapeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source_1 = op->operand_source(1); + std::string operand_source_1_id = pir::GetValueId(&operand_source_1); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + symbol::ShapeOrDataDimExprs shape_data; + + shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id]; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } +bool Reshape_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return ReshapeOpInferSymbolicShape(op, shape_analysis); +} + } // namespace paddle::dialect +namespace cinn::dialect { + +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + std::vector dims = + common::vectorize(res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = res_dim_expr; + } else { + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; + } + shapes.push_back(dim_expr); + } + + // pir::AttributeMap attributes = op->attributes(); + + // auto attr_starts = + // attributes["starts"].dyn_cast().AsVector(); + // auto start = attr_starts[0].dyn_cast().data(); + + // auto attr_ends = + // attributes["ends"].dyn_cast().AsVector(); + // auto end = attr_ends[0].dyn_cast().data(); + + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + return true; +} + +} // namespace cinn::dialect + IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index 46ccf56183b2a..b1c72e3111df2 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/pir/core/op_base.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" // Type inference is currently modelled executionally for operation creation // using the `InferMetaInterface`. While `InferSymbolicShapeInterface` is used @@ -31,54 +32,82 @@ class InferSymbolicShapeInterface /// Defined these methods with the interface. struct Concept { explicit Concept(bool (*infer_symbolic_shapes)( - pir::Operation* op, - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes)) // NOLINT + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis)) : infer_symbolic_shapes(infer_symbolic_shapes) {} bool (*infer_symbolic_shapes)( - pir::Operation* op, - pir::Builder& builder, - const std::vector& operands, - std::vector& reified_return_shapes); // NOLINT + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); }; template struct Model : public Concept { static inline bool InferSymbolicShape( - pir::Operation* op, - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes) { // NOLINT - return op->dyn_cast().InferSymbolicShape( - builder, operands, reified_return_shapes); + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return op->dyn_cast().InferSymbolicShape(shape_analysis); } Model() : Concept(InferSymbolicShape) {} }; /// Constructor - InferSymbolicShapeInterface(pir::Operation* op, Concept* impl) + InferSymbolicShapeInterface(pir::Operation *op, Concept *impl) : pir::OpInterfaceBase(op), impl_(impl) {} - bool InferSymbolicShape( - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes); // NOLINT + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); private: - Concept* impl_; + Concept *impl_; }; -bool AbsOpInferSymbolicShape( - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes); // NOLINT -bool Abs_OpInferSymbolicShape( - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes); // NOLINT +} // namespace paddle::dialect + +namespace paddle::dialect { + +bool AbsOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Abs_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool CastOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Cast_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ExpOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Exp_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool SubtractOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Subtract_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ShapeOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ShapeSrOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool StackOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ReshapeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Reshape_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); } // namespace paddle::dialect +namespace cinn::dialect { + +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +} + IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 7b5959a542e7a..6e2e105d9c18a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -29,6 +29,32 @@ namespace paddle { namespace dialect { +struct CombineOpInferSymbolicShapeInterfaceModel + : public InferSymbolicShapeInterface::Concept { + static inline bool InferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + symbol::ShapeOrDataDimExprs value_shape; + + // for (auto operand_source : op->operands_source()) { + // std::string operand_source_id = pir::GetValueId(&operand_source); + // auto source_shape_vec = + // shape_analysis->value_id_to_shapeordata_[operand_source_id]; + // for (int i = 0; i < source_shape_vec.size(); i++) { + // value_shape.second.emplace_back(source_shape_vec[i]); + // } + // } + + auto res = op->result(0); + auto res_id = pir::GetValueId(&res); + + shape_analysis->value_id_to_shapeordata_[res_id] = value_shape; + return true; + } + + CombineOpInferSymbolicShapeInterfaceModel() + : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} +}; + OperatorDialect::OperatorDialect(pir::IrContext *ctx) : pir::Dialect(name(), ctx, pir::TypeId::get()) { initialize(); @@ -36,6 +62,11 @@ OperatorDialect::OperatorDialect(pir::IrContext *ctx) auto info = ctx->GetRegisteredOpInfo(pir::TuplePushOp::name()); info.AttachInterface(std::move( pir::InterfaceValue::Get())); + + info = ctx->GetRegisteredOpInfo(pir::CombineOp::name()); + info.AttachInterface(std::move( + pir::InterfaceValue::Get())); } void OperatorDialect::initialize() { diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 57d7857a2498c..b4e7ddf02179c 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -265,6 +265,7 @@ data_type : x inplace: (x -> out) backward : cast_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : channel_shuffle args : (Tensor x, int groups, str data_format="NCHW") @@ -1044,6 +1045,7 @@ view: (x -> out) intermediate : xshape backward: reshape_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : rnn args: (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false) @@ -1214,6 +1216,7 @@ func : subtract inplace : (x -> out) backward : subtract_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : sum args : (Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false) diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index a7d32c6577906..5c6481110034e 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -111,8 +111,8 @@ class InferSymbolicShapePass : public pir::Pass { if (it != infer_sym_shape_map.end()) { it->second(op, shape_analysis_); } else { - VLOG(3) << "[" << op.name() - << "] is not supported for infer_symbolic_shape pass."; + LOG(WARNING) << "[" << op.name() + << "] is not supported for infer_symbolic_shape pass."; } } @@ -206,7 +206,7 @@ struct ExpandShapeOfOpPattern : public OpRewritePattern { bool MatchAndRewrite(shape::ShapeOfOp op, PatternRewriter& rewriter) const override { - VLOG(5) << "Apply ExpandShapeOfOpPattern..."; + VLOG(3) << "Apply ExpandShapeOfOpPattern..."; auto type = op.out().type().dyn_cast(); @@ -233,44 +233,6 @@ struct DimOfShapedTypeOpInterfacePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; bool MatchAndRewrite(OpTy dim_op, PatternRewriter& rewriter) const override { - OpResult dim_value = dim_op.source().template dyn_cast(); - if (!dim_value) return false; - - auto shaped_type_op = - dim_value.owner() - ->dyn_cast(); - if (!shaped_type_op) return false; - - std::optional dim_index = dim_op.GetConstantIndex(); - if (!dim_index) return false; - - std::vector reified_result_shapes; - if (!shaped_type_op.InferSymbolicShape( - rewriter, shaped_type_op->operands(), reified_result_shapes)) - return false; - - if (reified_result_shapes.size() != shaped_type_op->num_results()) - return false; - - Value result_shape = reified_result_shapes[dim_value.index()]; - auto result_shape_type = result_shape.type().dyn_cast(); - auto shaped_type = result_shape_type.dyn_cast(); - if (!result_shape_type || !shaped_type.GetElementType().IsIntOrIndex()) - return false; - - // TODO(zhangbopd): BuildOrFold required. - std::vector indices; - indices.push_back(rewriter.Build(*dim_index).out()); - - Value new_value = - rewriter.Build(result_shape, indices).out(); - - if (!new_value.type().isa()) - new_value = - rewriter.Build(rewriter.index_type(), new_value) - .out(); - - rewriter.ReplaceOp(dim_op, {new_value}); return true; } }; @@ -349,19 +311,6 @@ bool ShapeComputationIRAnalysis::Run() { // Make sure only run once. if (initialized_) return false; initialized_ = true; - // auto build_shape_func = - // std::bind(&ShapeComputationIRAnalysis::BuildShapeOnOperation, - // this, - // std::placeholders::_1); - // if (!RunOnRegion(&(m_->region(0)), build_shape_func)) return false; - // auto apply_op_constraint_func = - // std::bind(&ShapeComputationIRAnalysis::ApplyOpConstraint, - // this, - // std::placeholders::_1); - // // TODO(zhangbopd): Delete the following 1 line and fix UT - // // `shape_optimization_test` - // return true; - // if (!RunOnRegion(&(m_->region(0)), apply_op_constraint_func)) return false; return true; } @@ -508,220 +457,81 @@ bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { return true; } -void print_program(pir::ModuleOp m, std::string mgs) { +void PrintProgram(pir::ModuleOp m, std::string mgs) { std::ostringstream print_stream; print_stream << "\n\n"; m.program()->Print(print_stream); print_stream << "\n\n"; - VLOG(5) << "===================== " << mgs << "\n" << print_stream.str(); -} - -bool IsShapeSpecialOp(const pir::Operation& op) { - auto name = op.name(); - if (name == "pd_op.shape" || name == "cinn_op.slice") { - return true; - } - - return false; -} - -bool IsAllEqualUnaryOp(const pir::Operation& op) { - auto name = op.name(); - if (name == "pd_op.exp" || name == "pd_op.cast") { - return true; - } - - return false; -} - -void InferSymbolicShapeAllEqualUnary( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - shape_analysis->value_to_valueshape_expr_[rst_id] = - shape_analysis->value_to_valueshape_expr_[operand_source_id]; -} - -bool IsAllEqualBinaryOp(const pir::Operation& op) { - auto name = op.name(); - if (name == "pd_op.subtract") { - return true; - } - - return false; -} - -void InferSymbolicShapeAllEqualBinary( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - shape_analysis->value_to_valueshape_expr_[rst_id] = - shape_analysis->value_to_valueshape_expr_[operand_source_id]; -} - -void InferSymbolicShapePdShape(pir::Operation* op, - pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - std::pair, std::vector> value_shape; - - auto type = rst.type(); - auto tensor_type = type.dyn_cast(); - auto ddim_vec = common::vectorize(tensor_type.dims()); - for (auto dim : ddim_vec) { - std::string sym_name = ""; - if (dim == -1) { - sym_name = shape_analysis->GetNextSymName(); - } else { - sym_name = std::to_string(dim); - } - value_shape.first.emplace_back(sym_name); - } - - value_shape.second = - shape_analysis->value_to_valueshape_expr_[operand_source_id].first; - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; -} - -void InferSymbolicShapeCinnSlice( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - std::pair, std::vector> value_shape; - - auto type = rst.type(); - auto tensor_type = type.dyn_cast(); - auto ddim_vec = common::vectorize(tensor_type.dims()); - for (auto dim : ddim_vec) { - std::string sym_name = ""; - if (dim == -1) { - sym_name = shape_analysis->GetNextSymName(); - } else { - sym_name = std::to_string(dim); - } - value_shape.first.emplace_back(sym_name); - } - - auto attributes = op->attributes(); - - auto attr_starts = attributes["starts"].dyn_cast().AsVector(); - auto start = attr_starts[0].dyn_cast().data(); - - auto attr_ends = attributes["ends"].dyn_cast().AsVector(); - auto end = attr_ends[0].dyn_cast().data(); - - auto source_shape_vec = - shape_analysis->value_to_valueshape_expr_[operand_source_id].second; - for (int i = start; i < end; i++) { - value_shape.second.emplace_back(source_shape_vec[i]); - } - - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; + VLOG(3) << "===================== " << mgs << " =====================\n" + << print_stream.str(); } -void InferSymbolicShapeBuiltinCombine( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - std::pair, std::vector> value_shape; - for (auto operand_source : op->operands_source()) { - auto operand_source_id = pir::GetValueId(&operand_source); - auto source_shape_vec = - shape_analysis->value_to_valueshape_expr_[operand_source_id].second; - for (int i = 0; i < source_shape_vec.size(); i++) { - value_shape.second.emplace_back(source_shape_vec[i]); - } - } - - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; -} - -void InferSymbolicShapeStack(pir::Operation* op, - pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - std::pair, std::vector> value_shape; - - value_shape.second = - shape_analysis->value_to_valueshape_expr_[operand_source_id].second; - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; -} - -void InferSymbolicShapeReshape(pir::Operation* op, - pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source_1 = op->operand_source(1); - auto operand_source_1_id = pir::GetValueId(&operand_source_1); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - - std::pair, std::vector> value_shape; - - value_shape.first = - shape_analysis->value_to_valueshape_expr_[operand_source_1_id].second; - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; -} - -void debug_print_op_info( +void DebugPrintOpInfo( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { - VLOG(5) << op->name() << ", num_operands: " << op->num_operands(); - for (auto& rst : op->results()) { - auto type = rst.type(); - auto value_id = pir::GetValueId(&rst); + VLOG(0) << op->name() << ", num_operands: " << op->num_operands(); + for (auto& res : op->results()) { + auto value_id = pir::GetValueId(&res); std::ostringstream print_stream; - print_stream << ">>>> result(" << rst.index() << ") 's ID: " << value_id; - if (shape_analysis != nullptr) { - auto value_shape = shape_analysis->value_to_valueshape_expr_[value_id]; - print_stream << ", value_shape.first: ["; - for (auto str : value_shape.first) { - print_stream << str << ", "; + print_stream << ">>>> result(" << res.index() << ") 's ID: " << value_id; + if (shape_analysis != nullptr) { + auto shape_data = shape_analysis->value_id_to_shapeordata_[value_id]; + print_stream << ", ShapeOrData.shape: ["; + + for (auto str : shape_data.shape()) { + int64_t* i = std::get_if(&str); + std::string* s = std::get_if(&str); + if (i) { + print_stream << *i << ", "; + } else if (s) { + print_stream << *s << ", "; + } } - print_stream << "], second: ["; - for (auto str : value_shape.second) { - print_stream << str << ", "; + + print_stream << "], ShapeOrData.data: ["; + if (shape_data.data().has_value()) { + for (auto str : shape_data.data().value()) { + int64_t* i = std::get_if(&str); + std::string* s = std::get_if(&str); + if (i) { + print_stream << *i << ", "; + } else if (s) { + print_stream << *s << ", "; + } + } } print_stream << "]\n"; } - VLOG(5) << print_stream.str(); + VLOG(0) << print_stream.str(); } } -void InferSymExprForAllValues(pir::ModuleOp module_op) { - auto shape_analysis_mgr = pir::ShapeAnalysisManager::Instance(); - pir::ShapeConstraintIRAnalysis& shape_analysis = +void InferSymExprForAllValues(ModuleOp module_op) { + auto shape_analysis_mgr = ShapeAnalysisManager::Instance(); + ShapeConstraintIRAnalysis& shape_analysis = shape_analysis_mgr.Get(module_op.program()); for (int i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { for (auto& op : block) { if (op.num_operands() == 0) { - // Need new syms for -1s - for (auto& rst : op.results()) { - auto value_id = pir::GetValueId(&rst); - std::pair, std::vector> - value_shape; - auto type = rst.type(); - auto tensor_type = type.dyn_cast(); - auto ddim_vec = common::vectorize(tensor_type.dims()); - for (auto dim : ddim_vec) { - std::string sym_name = ""; + for (auto& res : op.results()) { + auto value_id = pir::GetValueId(&res); + + std::vector dims = common::vectorize( + res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; if (dim == -1) { - sym_name = shape_analysis.GetNextSymName(); + symbol::DimExpr res_dim_expr(shape_analysis.GetNextSymName()); + dim_expr = res_dim_expr; } else { - sym_name = std::to_string(dim); + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; } - value_shape.first.emplace_back(sym_name); + shapes.push_back(dim_expr); } if (op.name() == "pd_op.full_int_array") { @@ -730,28 +540,23 @@ void InferSymExprForAllValues(pir::ModuleOp module_op) { auto arr = attr.dyn_cast(); const auto& vec = arr.AsVector(); for (auto item : vec) { - auto i = item.dyn_cast(); - value_shape.second.emplace_back(std::to_string(i.data())); + int64_t i = item.dyn_cast().data(); + shapes.push_back(symbol::DimExpr(i)); } } - shape_analysis.value_to_valueshape_expr_[value_id] = value_shape; + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis.value_id_to_shapeordata_[value_id] = shape_data; + } + } else { + auto infer_symbolic_shape_interface = + op.dyn_cast(); + if (infer_symbolic_shape_interface) { + PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( + &shape_analysis)); } - } else if (IsAllEqualUnaryOp(op)) { - InferSymbolicShapeAllEqualUnary(&op, &shape_analysis); - } else if (IsAllEqualBinaryOp(op)) { - InferSymbolicShapeAllEqualBinary(&op, &shape_analysis); - } else if (op.name() == "pd_op.shape") { - InferSymbolicShapePdShape(&op, &shape_analysis); - } else if (op.name() == "cinn_op.slice") { - InferSymbolicShapeCinnSlice(&op, &shape_analysis); - } else if (op.name() == "builtin.combine") { - InferSymbolicShapeBuiltinCombine(&op, &shape_analysis); - } else if (op.name() == "pd_op.stack") { - InferSymbolicShapeStack(&op, &shape_analysis); - } else if (op.name() == "pd_op.reshape") { - InferSymbolicShapeReshape(&op, &shape_analysis); } - debug_print_op_info(&op, &shape_analysis); + + DebugPrintOpInfo(&op, &shape_analysis); } } } @@ -762,11 +567,11 @@ class ShapeOptimizationPass : public pir::Pass { ShapeOptimizationPass() : pir::Pass("shape_optimization_pass", 0) {} void Run(pir::Operation* op) override { - VLOG(5) << "===================== ShapeOptimizationPass Run start... " + VLOG(3) << "===================== ShapeOptimizationPass Run start... " "============================="; auto module_op = op->dyn_cast(); IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); - print_program(module_op, "Origin Program:"); + PrintProgram(module_op, "Origin Program"); InferSymExprForAllValues(module_op); MaterializeShapeComputation(module_op); @@ -777,7 +582,7 @@ class ShapeOptimizationPass : public pir::Pass { // if (!OptimizeShapeComputation(module_op, runner)) { // return; // } - VLOG(5) << "===================== ShapeOptimizationPass Run End. " + VLOG(3) << "===================== ShapeOptimizationPass Run End. " "============================="; } diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index de7c49250ea16..de4d700cdf80e 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -841,6 +841,7 @@ func : exp inplace : (x -> out) backward : exp_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : expand args : (Tensor x, IntArray shape = {}) @@ -2355,6 +2356,7 @@ shape_sr {selected_rows -> dense} data_transform: skip_transform : input + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : shard_index args : (Tensor input, int index_num, int nshards, int shard_id, int ignore_value=-1) @@ -2538,6 +2540,7 @@ kernel : func : stack backward : stack_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : stanh args : (Tensor x, float scale_a=0.67f, float scale_b=1.7159f) diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 717b05eb8fede..ac72c0bae88c7 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -76,11 +76,6 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { Value rhs, std::vector rhs_dim_idxs) override; - std::unordered_map< - std::string, - std::pair, std::vector>> - value_to_valueshape_expr_; - inline const std::string GetNextSymName() { return "S" + std::to_string(next_sym_idx_++); } @@ -89,6 +84,9 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { symbol::DimExprBuilder CreateDimExprBuilder() override; + std::unordered_map + value_id_to_shapeordata_; + private: // The operation this analysis runs on. ModuleOp m_; @@ -99,9 +97,6 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { std::unordered_map> value_to_sym_dims_; - std::unordered_map - value_id_to_shapeordata; - int64_t next_sym_idx_ = 0; std::vector constraints_;