From fe38b6986fea9b2b39a3debda5b8eb85cd940e53 Mon Sep 17 00:00:00 2001 From: Hongqing-work <76149632+Hongqing-work@users.noreply.github.com> Date: Tue, 6 Aug 2024 11:15:36 +0800 Subject: [PATCH] [CINN]Fix InferSymbolicShape for assign_value and reshape (#67039) --- .../infer_symbolic_shape/nullary_infer_sym.cc | 29 ++++++++++++------- .../infer_symbolic_shape/unary_infer_sym.cc | 3 +- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc index 611dbccf0d68b5..d82b5bb6fabdb0 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -60,18 +60,27 @@ bool AssignValueOpInferSymbolicShape( sym_dims.emplace_back(symbol::DimExpr(static_cast(dim))); } - const auto &attributes = op->attributes(); + bool result_is_int_type = [&]() { + const auto &dtype = + op->result(0).type().dyn_cast().dtype(); + return dtype.isa() || dtype.isa(); + }(); + std::vector values; - for (size_t i = 0; - i < attributes.at("values").dyn_cast().size(); - i++) { - values.push_back(attributes.at("values") - .dyn_cast() - .at(i) - .dyn_cast() - .data() - .to()); + if (result_is_int_type) { + const auto &attributes = op->attributes(); + for (size_t i = 0; + i < attributes.at("values").dyn_cast().size(); + i++) { + values.push_back(attributes.at("values") + .dyn_cast() + .at(i) + .dyn_cast() + .data() + .to()); + } } + if (values.size() > 0 && sym_dims.size() <= 1) { std::vector data; for (const auto &value : values) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 7f28e881144a22..eddc3c68ee872f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1294,7 +1294,8 @@ bool ReshapeOpInferSymbolicShape( const std::vector out_dims = [&] { const auto &original_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); - ExprVec target_shape = details::GetExprVecFromData(shape_dim_expr); + ExprVec target_shape = + details::GetOrCreateExprVecFromData(shape_dim_expr, infer_context); // replace '0' with original shape for (size_t i = 0; i < target_shape.size(); i++) {