-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use DimExpr and change InferSymbolicShapeInterface #60371
Changes from all commits
aced1f7
0d50945
0c39f1a
dc61668
ab4790e
9c28f50
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,141 +13,192 @@ | |
// limitations under the License. | ||
|
||
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" | ||
#include "paddle/pir/core/builtin_attribute.h" | ||
#include "paddle/pir/core/builtin_type.h" | ||
#include "paddle/pir/dialect/shape/ir/shape_op.h" | ||
|
||
namespace paddle::dialect { | ||
|
||
bool InferSymbolicShapeInterface::InferSymbolicShape( | ||
pir::Builder &builder, | ||
const std::vector<pir::OpOperand> &operands, | ||
std::vector<pir::Value> &reified_return_shapes) { | ||
return impl_->infer_symbolic_shapes( | ||
operation(), builder, operands, reified_return_shapes); | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return impl_->infer_symbolic_shapes(operation(), shape_analysis); | ||
} | ||
} // namespace paddle::dialect | ||
|
||
namespace paddle::dialect { | ||
|
||
namespace { | ||
|
||
bool DeriveShapeFromOperand(pir::Builder *builder, | ||
pir::Value operand, | ||
std::vector<pir::Value> *reified_return_shapes) { | ||
auto shaped_type = operand.type().dyn_cast<pir::ShapedTypeInterface>(); | ||
if (!shaped_type) return false; | ||
reified_return_shapes->assign( | ||
{builder->Build<pir::shape::ShapeOfOp>(operand).result(0)}); | ||
bool InferSymbolicShapeAllEqualUnary( | ||
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
pir::Value operand_source = op->operand_source(0); | ||
std::string operand_source_id = pir::GetValueId(&operand_source); | ||
pir::OpResult res = op->result(0); | ||
std::string res_id = pir::GetValueId(&res); | ||
shape_analysis->value_id_to_shapeordata_[res_id] = | ||
shape_analysis->value_id_to_shapeordata_[operand_source_id]; | ||
return true; | ||
} | ||
|
||
// Returns a new scalar integer value having type `type`. | ||
// Here `type` must be an integer or index type. | ||
pir::Value MaybeCastTo(pir::Builder &builder, // NOLINT | ||
pir::Value value, | ||
pir::Type type) { | ||
if (type == value.type()) return value; | ||
// if (!type.IsIndex() && !value.type().IsIndex()) { | ||
// Value casted = | ||
// builder.Build<shape::IndexCastOp>(builder.index_type(), value) | ||
// .result(0); | ||
// return builder.Build<shape::IndexCastOp>(type, casted).result(0); | ||
// } | ||
// return builder.Build<shape::IndexCastOp>(type, value).result(0); | ||
bool InferSymbolicShapeAllEqualBinary( | ||
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
pir::Value operand_source = op->operand_source(0); | ||
std::string operand_source_id = pir::GetValueId(&operand_source); | ||
pir::OpResult res = op->result(0); | ||
std::string res_id = pir::GetValueId(&res); | ||
shape_analysis->value_id_to_shapeordata_[res_id] = | ||
shape_analysis->value_id_to_shapeordata_[operand_source_id]; | ||
Comment on lines
+46
to
+50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value可以直接用作map的key吗?看着用value id会多一些处理 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以 |
||
return true; | ||
} | ||
|
||
} // namespace | ||
|
||
bool AbsOpInferSymbolicShape( | ||
pir::Builder &builder, // NOLINT | ||
const std::vector<pir::OpOperand> &operands, | ||
std::vector<pir::Value> &reified_return_shapes) { // NOLINT | ||
return DeriveShapeFromOperand( | ||
&builder, operands.front().source(), &reified_return_shapes); | ||
} | ||
|
||
bool Abs_OpInferSymbolicShape( | ||
pir::Builder &builder, // NOLINT | ||
const std::vector<pir::OpOperand> &operands, | ||
std::vector<pir::Value> &reified_return_shapes) { // NOLINT | ||
return DeriveShapeFromOperand( | ||
&builder, operands.front().source(), &reified_return_shapes); | ||
} | ||
|
||
bool TransposeOpInferSymbolicShape( | ||
pir::Builder &builder, // NOLINT | ||
const std::vector<pir::OpOperand> &operands, | ||
std::vector<pir::Value> &reified_return_shapes) { // NOLINT | ||
// auto operand_type = operands[0].type().dyn_cast<DenseTensorType>(); | ||
// // Currently not support unranked type. | ||
// if (!operand_type) return false; | ||
// std::vector<int64_t> permutation = this->permutation(); | ||
// std::vector<Value> shape_values(permutation.size()); | ||
// Type shape_scalar_type = builder.index_type(); | ||
// auto to_shape_scalar_type = [&](Value v) { | ||
// return MaybeCastTo(builder, v, shape_scalar_type); | ||
// }; | ||
// auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>(); | ||
// auto shape_vector = shaped_type.GetDyShape(); | ||
// for (auto [idx, element] = std::tuple{0, shape_vector.begin()}; | ||
// element != shape_vector.end(); | ||
// ++idx, ++element) { | ||
// auto it = std::find(permutation.begin(), permutation.end(), idx); | ||
// // TODO(zhangbopd): Need BuildOrFold | ||
// Value value_dim = to_shape_scalar_type( | ||
// builder.Build<shape::TensorDimOp>(operands[0].source(), | ||
// idx).result(0)); | ||
// shape_values[std::distance(permutation.begin(), it)] = value_dim; | ||
// } | ||
// Value output_shape = | ||
// builder.Build<shape::FromElementsOp>(shape_values).result(0); | ||
// reified_return_shapes.push_back(output_shape); | ||
bool AbsOpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return InferSymbolicShapeAllEqualUnary(op, shape_analysis); | ||
} | ||
|
||
bool Abs_OpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return InferSymbolicShapeAllEqualUnary(op, shape_analysis); | ||
} | ||
|
||
bool CastOpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return InferSymbolicShapeAllEqualUnary(op, shape_analysis); | ||
} | ||
|
||
bool Cast_OpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return InferSymbolicShapeAllEqualUnary(op, shape_analysis); | ||
} | ||
|
||
bool ExpOpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return InferSymbolicShapeAllEqualUnary(op, shape_analysis); | ||
} | ||
|
||
bool Exp_OpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return InferSymbolicShapeAllEqualUnary(op, shape_analysis); | ||
} | ||
|
||
bool SubtractOpInferSymbolicShape( | ||
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return InferSymbolicShapeAllEqualBinary(op, shape_analysis); | ||
} | ||
|
||
bool Subtract_OpInferSymbolicShape( | ||
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return InferSymbolicShapeAllEqualBinary(op, shape_analysis); | ||
} | ||
|
||
bool ShapeOpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
pir::Value operand_source = op->operand_source(0); | ||
std::string operand_source_id = pir::GetValueId(&operand_source); | ||
pir::OpResult res = op->result(0); | ||
std::string res_id = pir::GetValueId(&res); | ||
|
||
std::vector<int64_t> dims = | ||
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims()); | ||
|
||
std::vector<symbol::DimExpr> shapes; | ||
for (int64_t dim : dims) { | ||
symbol::DimExpr dim_expr; | ||
if (dim == -1) { | ||
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); | ||
dim_expr = res_dim_expr; | ||
} else { | ||
symbol::DimExpr res_dim_expr(dim); | ||
dim_expr = res_dim_expr; | ||
} | ||
shapes.push_back(dim_expr); | ||
} | ||
|
||
symbol::ShapeOrDataDimExprs shape_data{shapes}; | ||
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; | ||
return true; | ||
} | ||
|
||
bool ConcatOpInferSymbolicShape( | ||
pir::Builder &builder, // NOLINT | ||
const std::vector<pir::OpOperand> &operands, | ||
std::vector<pir::Value> &reified_return_shapes) { // NOLINT | ||
// std::vector<Value> inputs = {x()}; | ||
// auto operand_type = inputs[0].type().dyn_cast<DenseTensorType>(); | ||
// // Currently not support unranked type. | ||
// if (!operand_type) return false; | ||
// Type shapeScalarType = builder.index_type(); | ||
// auto to_shape_scalar_type = [&](Value v) { | ||
// return MaybeCastTo(builder, v, shapeScalarType); | ||
// }; | ||
// std::vector<std::vector<Value>> all_shape_values; | ||
// for (size_t inputId = 0; inputId < inputs.size(); ++inputId) { | ||
// Value operand = inputs[inputId]; | ||
// auto operand_type = operand.type().dyn_cast<DenseTensorType>(); | ||
// if (!operand_type) return false; | ||
// std::vector<Value> shape_values; | ||
// auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>(); | ||
// auto shape_vector = shaped_type.GetDyShape(); | ||
// for (auto [idx, element] = std::tuple{0, shape_vector.begin()}; | ||
// element != shape_vector.end(); | ||
// ++idx, ++element) { | ||
// Value value_dim = to_shape_scalar_type( | ||
// builder.Build<shape::TensorDimOp>(operand, idx).result(0)); | ||
// shape_values.push_back(value_dim); | ||
// } | ||
// all_shape_values.emplace_back(std::move(shape_values)); | ||
// } | ||
// [[maybe_unused]] int axis = this->dimension(); | ||
// auto &shape_values = all_shape_values[0]; | ||
// for (size_t vecId = 1; vecId < all_shape_values.size(); ++vecId) { | ||
// auto &otherShapeValues = all_shape_values[vecId]; | ||
// if (otherShapeValues.size() != shape_values.size()) return false; | ||
// TODO(zhangbopd): AddIOp | ||
// shape_values[axis] = | ||
// builder.Build<arith::AddIOp>(shape_values[axis], | ||
// otherShapeValues[axis]); | ||
// } | ||
// Value output_shape = | ||
// builder.Build<shape::FromElementsOp>(shape_values).result(0); | ||
// reified_return_shapes.push_back(output_shape); | ||
bool ShapeSrOpInferSymbolicShape( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shape"Sr"Op 中间的Sr是什么含义? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. selected_rows 版本的算子,ShapeSrOp 和 ShapeSrOp 共用了一个 yaml 配置 |
||
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return ShapeOpInferSymbolicShape(op, shape_analysis); | ||
} | ||
|
||
bool StackOpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
pir::Value operand_source = op->operand_source(0); | ||
std::string operand_source_id = pir::GetValueId(&operand_source); | ||
pir::OpResult res = op->result(0); | ||
std::string res_id = pir::GetValueId(&res); | ||
|
||
symbol::ShapeOrDataDimExprs shape_data; | ||
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id]; | ||
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #60146 这个PR合入之后,可以增加一些调试信息输出,便于以后调试使用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,这个PR里 重写了 DebugPrintOpInfo,目前也支持 符号和 int_64 两种类型的 Dim_Expr 打印,后续会改用新的接口 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #60146 已合入 |
||
return true; | ||
} | ||
|
||
bool ReshapeOpInferSymbolicShape( | ||
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
pir::Value operand_source_1 = op->operand_source(1); | ||
std::string operand_source_1_id = pir::GetValueId(&operand_source_1); | ||
pir::OpResult res = op->result(0); | ||
std::string res_id = pir::GetValueId(&res); | ||
|
||
symbol::ShapeOrDataDimExprs shape_data; | ||
|
||
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id]; | ||
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里应该后面要用到宏宇PR的一个接口,下个PR改,这里data应该还是空的 |
||
return true; | ||
} | ||
|
||
bool Reshape_OpInferSymbolicShape( | ||
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
return ReshapeOpInferSymbolicShape(op, shape_analysis); | ||
} | ||
|
||
} // namespace paddle::dialect | ||
namespace cinn::dialect { | ||
|
||
bool SliceOpInferSymbolicShape(pir::Operation *op, | ||
pir::ShapeConstraintIRAnalysis *shape_analysis) { | ||
pir::Value operand_source = op->operand_source(0); | ||
std::string operand_source_id = pir::GetValueId(&operand_source); | ||
pir::OpResult res = op->result(0); | ||
std::string res_id = pir::GetValueId(&res); | ||
|
||
std::vector<int64_t> dims = | ||
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims()); | ||
|
||
std::vector<symbol::DimExpr> shapes; | ||
for (int64_t dim : dims) { | ||
symbol::DimExpr dim_expr; | ||
if (dim == -1) { | ||
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); | ||
dim_expr = res_dim_expr; | ||
} else { | ||
symbol::DimExpr res_dim_expr(dim); | ||
dim_expr = res_dim_expr; | ||
} | ||
shapes.push_back(dim_expr); | ||
} | ||
|
||
// pir::AttributeMap attributes = op->attributes(); | ||
|
||
// auto attr_starts = | ||
// attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector(); | ||
// auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data(); | ||
|
||
// auto attr_ends = | ||
// attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector(); | ||
// auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data(); | ||
|
||
symbol::ShapeOrDataDimExprs shape_data{shapes}; | ||
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; | ||
return true; | ||
} | ||
|
||
} // namespace cinn::dialect | ||
|
||
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
正式代码里最好不要直接操作成员变量了,需要使用封装接口操作
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到