diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 820846cacca6b..99ebd6a370b4a 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -382,6 +382,7 @@ set(IR_PASS_DEPS if(WITH_CINN) set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass) + set(IR_PASS_DEPS ${IR_PASS_DEPS} cinn_zero_tensor_trick_pass) endif() if(NOT APPLE diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 91024a9dbe317..b0349966bb5d7 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -57,6 +57,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { #ifdef PADDLE_WITH_CINN if (FLAGS_use_cinn || strategy.build_cinn_pass_) { + // Note: This is a trick to support 0D-Tensor for CINN. This pass will be + // removed in the near future. + AppendPass("cinn_zero_tensor_trick_pass"); // Note: This pass is used to enable cinn. AppendPass("build_cinn_pass"); AppendPrintGraphPass("graph_viz_pass", "_build_cinn_graph"); @@ -532,6 +535,7 @@ USE_PASS(fused_attention_pass); USE_PASS(fuse_adamw_op_pass); #endif #ifdef PADDLE_WITH_CINN +USE_PASS(cinn_zero_tensor_trick_pass); USE_PASS(build_cinn_pass); #endif #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index a082dff6e54c2..f6a183304075b 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -8,6 +8,8 @@ pass_library( errors enforce) +pass_library(cinn_zero_tensor_trick_pass base) + cc_library( transform_desc SRCS transform_desc.cc @@ -62,6 +64,20 @@ if(WITH_TESTING) set_tests_properties(build_cinn_pass_test PROPERTIES LABELS "RUN_TYPE=CINN") target_link_libraries(build_cinn_pass_test ${PYTHON_LIBRARIES}) + cc_test_old( + cinn_zero_tensor_trick_pass_test + SRCS + cinn_zero_tensor_trick_pass_test.cc + DEPS + build_cinn_pass + cinn_compiler + op_registry + elementwise_add_op + generated_op) + set_tests_properties(cinn_zero_tensor_trick_pass_test + PROPERTIES LABELS "RUN_TYPE=CINN") + target_link_libraries(cinn_zero_tensor_trick_pass_test ${PYTHON_LIBRARIES}) + cc_test_old(transform_desc_test SRCS transform_desc_test.cc DEPS transform_desc) set_tests_properties(transform_desc_test PROPERTIES LABELS "RUN_TYPE=CINN") diff --git a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc new file mode 100644 index 0000000000000..9c4e6192be424 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.h" + +#include +#include "glog/logging.h" + +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const { + // fix shape attr of these ops + const std::unordered_set op_cases_fix_attr{"fill_constant", + "uniform_random", + "expand_v2", + "assign_value", + "gaussian_random", + "set_value"}; + for (const ir::Node* n : graph->Nodes()) { + if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) { + if (n->Op()->HasAttr("shape")) { + auto attr_type = n->Op()->GetAttrType("shape"); + if (attr_type == paddle::framework::proto::INTS) { + auto shapes = + PADDLE_GET_CONST(std::vector, n->Op()->GetAttr("shape")); + if (shapes.empty()) { + shapes.push_back(1); + n->Op()->SetAttr("shape", shapes); + VLOG(4) << "op " << n->Op()->Type() + << " shape attribute dims is empty, fix dim -> {1} "; + } + } else { /* attr_type == paddle::framework::proto::LONGS */ + auto shapes = + PADDLE_GET_CONST(std::vector, n->Op()->GetAttr("shape")); + if (shapes.empty()) { + shapes.push_back(1); + n->Op()->SetAttr("shape", shapes); + VLOG(4) << "op " << n->Op()->Type() + << " shape attribute dims is empty, fix dim -> {1} "; + } + } + } + } + if (n->IsVar()) { + if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) { + std::vector shape = n->Var()->GetShape(); + if (shape.empty()) { + shape.push_back(1); + n->Var()->SetShape(shape); + VLOG(4) << "var " << n->Name() << " dims is empty, fix dim -> {1} "; + } + } + } + } +} + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle + +REGISTER_PASS(cinn_zero_tensor_trick_pass, + paddle::framework::paddle2cinn::CinnZeroTensorTrickPass); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.h b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.h new file mode 100644 index 0000000000000..57e8beb6dacf8 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +class Graph; + +class CinnZeroTensorTrickPass : public framework::ir::Pass { + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass_test.cc new file mode 100644 index 0000000000000..ff07ce2f3de50 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass_test.cc @@ -0,0 +1,56 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.h" + +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +TEST(CinnZeroTensorTrickPass, basic) { + ir::Layers layers; + auto* x = layers.data("x", {}); + auto* y = layers.data("y", {3, 4}); + auto* add_out_0 = layers.elementwise_add(x, y, nullptr, 0); + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = ir::PassRegistry::Instance().Get("cinn_zero_tensor_trick_pass"); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + VLOG(3) << DebugString(graph); + + for (auto* n : graph->Nodes()) { + if (n->IsVar()) { + if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) { + std::vector shape = n->Var()->GetShape(); + PADDLE_ENFORCE_EQ( + shape.empty(), + false, + platform::errors::PreconditionNotMet( + "The shape of elementwise_add should not be empty after fuse")); + } + } + } +} + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle + +USE_PASS(cinn_zero_tensor_trick_pass);