Skip to content

Commit

Permalink
[CINN Support 0D-Tensor] CINN hack squeeze2 with trick temporarily (P…
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 authored and zhwesky2010 committed May 8, 2023
1 parent 904dc64 commit b846ec5
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,32 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
"assign_value",
"gaussian_random",
"set_value"};
// NOTE: Hack squeeze2 0D-Tensor input
// If squeeze2 inputs 0D-Tensor and axes, The 0D-Tensor's shape will convert
// to 1D-Tensor, which could lead error. We hack squeeze2's axes attribute to
// resolve this. Change 0D-Tensor input to 1D-Tensor input and then make
// axes->axes[: -1]
for (const ir::Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->Type() == "unsqueeze2") {
if (n->Op()->HasAttr("axes")) {
auto axes =
PADDLE_GET_CONST(std::vector<int32_t>, n->Op()->GetAttr("axes"));
for (const ir::Node* var : n->inputs) {
if (var->Var() &&
var->Var()->GetType() == proto::VarType::LOD_TENSOR) {
std::vector<int64_t> shape = var->Var()->GetShape();
if (shape.empty()) {
axes.pop_back();
n->Op()->SetAttr("axes", axes);
VLOG(4) << "unsqueeze2 axes dims is full, fix dim -> dim[:-1] to "
"avoid 0D-Tensor input error";
}
}
}
}
}
}

for (const ir::Node* n : graph->Nodes()) {
if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) {
if (n->Op()->HasAttr("shape")) {
Expand Down

0 comments on commit b846ec5

Please sign in to comment.