Skip to content

Commit

Permalink
[CINN]Fix InferSymbolicShape for assign_value and reshape (#67039)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hongqing-work authored and pull[bot] committed Sep 2, 2024
1 parent e62cb61 commit fe38b69
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,27 @@ bool AssignValueOpInferSymbolicShape(
sym_dims.emplace_back(symbol::DimExpr(static_cast<int64_t>(dim)));
}

const auto &attributes = op->attributes();
bool result_is_int_type = [&]() {
const auto &dtype =
op->result(0).type().dyn_cast<pir::DenseTensorType>().dtype();
return dtype.isa<pir::Int32Type>() || dtype.isa<pir::Int64Type>();
}();

std::vector<int64_t> values;
for (size_t i = 0;
i < attributes.at("values").dyn_cast<pir::ArrayAttribute>().size();
i++) {
values.push_back(attributes.at("values")
.dyn_cast<pir::ArrayAttribute>()
.at(i)
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int64_t>());
if (result_is_int_type) {
const auto &attributes = op->attributes();
for (size_t i = 0;
i < attributes.at("values").dyn_cast<pir::ArrayAttribute>().size();
i++) {
values.push_back(attributes.at("values")
.dyn_cast<pir::ArrayAttribute>()
.at(i)
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int64_t>());
}
}

if (values.size() > 0 && sym_dims.size() <= 1) {
std::vector<symbol::DimExpr> data;
for (const auto &value : values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1294,7 +1294,8 @@ bool ReshapeOpInferSymbolicShape(
const std::vector<symbol::DimExpr> out_dims = [&] {
const auto &original_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
ExprVec target_shape = details::GetExprVecFromData(shape_dim_expr);
ExprVec target_shape =
details::GetOrCreateExprVecFromData(shape_dim_expr, infer_context);

// replace '0' with original shape
for (size_t i = 0; i < target_shape.size(); i++) {
Expand Down

0 comments on commit fe38b69

Please sign in to comment.