Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DimExpr and change InferSymbolicShapeInterface #60371

Merged
merged 6 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
func : SliceRawInferMeta
kernel :
func : slice
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : uniform_random
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/inference/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ set(KERNEL_LIST

# shared inference library deps
list(REMOVE_DUPLICATES fluid_modules)
#windows GPU static library over the limit, so not create_static_lib, and cc_library is dummy
if(WIN32 AND WITH_GPU)
# windows static library (both CPU and GPU) over the limit, so no longer create_static_lib,
# and cc_library is dummy
if(WIN32)
cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API}
${utils_modules})
else()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
# limitations under the License.

OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """
bool {op_name}::InferSymbolicShape(pir::Builder &builder,
const std::vector<pir::OpOperand> &operands,
std::vector<pir::Value> &reified_return_shapes) {{
bool {op_name}::InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis) {{
VLOG(4) << "Infer symbolic shape for op: {op_name}";
return {op_name}InferSymbolicShape(builder, operands, reified_return_shapes);
return {op_name}InferSymbolicShape(this->operation(), shape_analysis);
}}
"""

Expand Down
4 changes: 1 addition & 3 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
"""

infer_symbolic_shape_template = """
static bool InferSymbolicShape(pir::Builder &builder,
const std::vector<pir::OpOperand> &operands,
std::vector<pir::Value> &reified_return_shapes);
bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis);
"""

# =====================================
Expand Down
273 changes: 162 additions & 111 deletions paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,141 +13,192 @@
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/dialect/shape/ir/shape_op.h"

namespace paddle::dialect {

bool InferSymbolicShapeInterface::InferSymbolicShape(
pir::Builder &builder,
const std::vector<pir::OpOperand> &operands,
std::vector<pir::Value> &reified_return_shapes) {
return impl_->infer_symbolic_shapes(
operation(), builder, operands, reified_return_shapes);
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return impl_->infer_symbolic_shapes(operation(), shape_analysis);
}
} // namespace paddle::dialect

namespace paddle::dialect {

namespace {

bool DeriveShapeFromOperand(pir::Builder *builder,
pir::Value operand,
std::vector<pir::Value> *reified_return_shapes) {
auto shaped_type = operand.type().dyn_cast<pir::ShapedTypeInterface>();
if (!shaped_type) return false;
reified_return_shapes->assign(
{builder->Build<pir::shape::ShapeOfOp>(operand).result(0)});
bool InferSymbolicShapeAllEqualUnary(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);
shape_analysis->value_id_to_shapeordata_[res_id] =
shape_analysis->value_id_to_shapeordata_[operand_source_id];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

正式代码里最好不要直接操作成员变量了,需要使用封装接口操作

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

收到

return true;
}

// Returns a new scalar integer value having type `type`.
// Here `type` must be an integer or index type.
pir::Value MaybeCastTo(pir::Builder &builder, // NOLINT
pir::Value value,
pir::Type type) {
if (type == value.type()) return value;
// if (!type.IsIndex() && !value.type().IsIndex()) {
// Value casted =
// builder.Build<shape::IndexCastOp>(builder.index_type(), value)
// .result(0);
// return builder.Build<shape::IndexCastOp>(type, casted).result(0);
// }
// return builder.Build<shape::IndexCastOp>(type, value).result(0);
bool InferSymbolicShapeAllEqualBinary(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);
shape_analysis->value_id_to_shapeordata_[res_id] =
shape_analysis->value_id_to_shapeordata_[operand_source_id];
Comment on lines +46 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value可以直接用作map的key吗?看着用value id会多一些处理

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以

return true;
}

} // namespace

bool AbsOpInferSymbolicShape(
pir::Builder &builder, // NOLINT
const std::vector<pir::OpOperand> &operands,
std::vector<pir::Value> &reified_return_shapes) { // NOLINT
return DeriveShapeFromOperand(
&builder, operands.front().source(), &reified_return_shapes);
}

bool Abs_OpInferSymbolicShape(
pir::Builder &builder, // NOLINT
const std::vector<pir::OpOperand> &operands,
std::vector<pir::Value> &reified_return_shapes) { // NOLINT
return DeriveShapeFromOperand(
&builder, operands.front().source(), &reified_return_shapes);
}

bool TransposeOpInferSymbolicShape(
pir::Builder &builder, // NOLINT
const std::vector<pir::OpOperand> &operands,
std::vector<pir::Value> &reified_return_shapes) { // NOLINT
// auto operand_type = operands[0].type().dyn_cast<DenseTensorType>();
// // Currently not support unranked type.
// if (!operand_type) return false;
// std::vector<int64_t> permutation = this->permutation();
// std::vector<Value> shape_values(permutation.size());
// Type shape_scalar_type = builder.index_type();
// auto to_shape_scalar_type = [&](Value v) {
// return MaybeCastTo(builder, v, shape_scalar_type);
// };
// auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>();
// auto shape_vector = shaped_type.GetDyShape();
// for (auto [idx, element] = std::tuple{0, shape_vector.begin()};
// element != shape_vector.end();
// ++idx, ++element) {
// auto it = std::find(permutation.begin(), permutation.end(), idx);
// // TODO(zhangbopd): Need BuildOrFold
// Value value_dim = to_shape_scalar_type(
// builder.Build<shape::TensorDimOp>(operands[0].source(),
// idx).result(0));
// shape_values[std::distance(permutation.begin(), it)] = value_dim;
// }
// Value output_shape =
// builder.Build<shape::FromElementsOp>(shape_values).result(0);
// reified_return_shapes.push_back(output_shape);
bool AbsOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
}

bool Abs_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
}

bool CastOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
}

bool Cast_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
}

bool ExpOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
}

bool Exp_OpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
}

bool SubtractOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
}

bool Subtract_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
}

bool ShapeOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);

std::vector<int64_t> dims =
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());

std::vector<symbol::DimExpr> shapes;
for (int64_t dim : dims) {
symbol::DimExpr dim_expr;
if (dim == -1) {
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
dim_expr = res_dim_expr;
} else {
symbol::DimExpr res_dim_expr(dim);
dim_expr = res_dim_expr;
}
shapes.push_back(dim_expr);
}

symbol::ShapeOrDataDimExprs shape_data{shapes};
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
return true;
}

bool ConcatOpInferSymbolicShape(
pir::Builder &builder, // NOLINT
const std::vector<pir::OpOperand> &operands,
std::vector<pir::Value> &reified_return_shapes) { // NOLINT
// std::vector<Value> inputs = {x()};
// auto operand_type = inputs[0].type().dyn_cast<DenseTensorType>();
// // Currently not support unranked type.
// if (!operand_type) return false;
// Type shapeScalarType = builder.index_type();
// auto to_shape_scalar_type = [&](Value v) {
// return MaybeCastTo(builder, v, shapeScalarType);
// };
// std::vector<std::vector<Value>> all_shape_values;
// for (size_t inputId = 0; inputId < inputs.size(); ++inputId) {
// Value operand = inputs[inputId];
// auto operand_type = operand.type().dyn_cast<DenseTensorType>();
// if (!operand_type) return false;
// std::vector<Value> shape_values;
// auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>();
// auto shape_vector = shaped_type.GetDyShape();
// for (auto [idx, element] = std::tuple{0, shape_vector.begin()};
// element != shape_vector.end();
// ++idx, ++element) {
// Value value_dim = to_shape_scalar_type(
// builder.Build<shape::TensorDimOp>(operand, idx).result(0));
// shape_values.push_back(value_dim);
// }
// all_shape_values.emplace_back(std::move(shape_values));
// }
// [[maybe_unused]] int axis = this->dimension();
// auto &shape_values = all_shape_values[0];
// for (size_t vecId = 1; vecId < all_shape_values.size(); ++vecId) {
// auto &otherShapeValues = all_shape_values[vecId];
// if (otherShapeValues.size() != shape_values.size()) return false;
// TODO(zhangbopd): AddIOp
// shape_values[axis] =
// builder.Build<arith::AddIOp>(shape_values[axis],
// otherShapeValues[axis]);
// }
// Value output_shape =
// builder.Build<shape::FromElementsOp>(shape_values).result(0);
// reified_return_shapes.push_back(output_shape);
bool ShapeSrOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shape"Sr"Op 中间的Sr是什么含义?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

selected_rows 版本的算子,ShapeOp 和 ShapeSrOp 共用了一个 yaml 配置

pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return ShapeOpInferSymbolicShape(op, shape_analysis);
}

bool StackOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);

symbol::ShapeOrDataDimExprs shape_data;
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id];
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#60146 这个PR合入之后,可以增加一些调试信息输出,便于以后调试使用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,这个PR里 重写了 DebugPrintOpInfo,目前也支持 符号和 int_64 两种类型的 Dim_Expr 打印,后续会改用新的接口

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#60146 已合入

return true;
}

bool ReshapeOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source_1 = op->operand_source(1);
std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);

symbol::ShapeOrDataDimExprs shape_data;

shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id];
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的shapedata数据是不是需要交换下?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该后面要用到宏宇PR的一个接口,下个PR改,这里data应该还是空的

return true;
}

bool Reshape_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return ReshapeOpInferSymbolicShape(op, shape_analysis);
}

} // namespace paddle::dialect
namespace cinn::dialect {

bool SliceOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
std::string operand_source_id = pir::GetValueId(&operand_source);
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);

std::vector<int64_t> dims =
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());

std::vector<symbol::DimExpr> shapes;
for (int64_t dim : dims) {
symbol::DimExpr dim_expr;
if (dim == -1) {
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
dim_expr = res_dim_expr;
} else {
symbol::DimExpr res_dim_expr(dim);
dim_expr = res_dim_expr;
}
shapes.push_back(dim_expr);
}

// pir::AttributeMap attributes = op->attributes();

// auto attr_starts =
// attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();
// auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();

// auto attr_ends =
// attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector();
// auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data();

symbol::ShapeOrDataDimExprs shape_data{shapes};
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
return true;
}

} // namespace cinn::dialect

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)
Loading